# External Validation of the BAN-ADHF Diuretic Resistance Score

## Reproducible Analysis Notebook

This notebook provides complete, reproducible code for validating the BAN-ADHF score in critically ill patients with acute decompensated heart failure using the MIMIC-IV database.

---

### Study Overview

**Background:** The BAN-ADHF score predicts diuretic efficiency in hospitalized heart failure patients, but derivation and validation cohorts excluded hemodynamically unstable patients. Whether this score maintains predictive validity in critically ill intensive care unit populations, where diuretic resistance is more prevalent and consequential, remains unknown.

**Methods:** We performed a retrospective cohort study using the MIMIC-IV database (2008-2022). We included 1,505 adult ICU patients with acute decompensated heart failure receiving intravenous diuretics. Co-primary outcomes were 24-hour and 72-hour diuretic efficiency (mL urine output per mg IV furosemide equivalent). We assessed rank correlation with Spearman's ρ, concordance with the C-index, and discrimination for the lowest efficiency quintile with the AUROC. Patients were stratified into low (≤7), moderate (8-12), and high (≥13) risk categories using cutoffs derived from this cohort.
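The outcome definition and risk stratification above can be sketched as follows. This is a minimal illustration, assuming the 24-hour urine output (mL) and the IV loop-diuretic dose have already been converted to mg of IV furosemide equivalents (that conversion happens in the `/sql/` cohort queries and is not reproduced here). The helper names and the three demo rows are hypothetical.

```python
import numpy as np
import pandas as pd

def diuretic_efficiency(urine_ml: pd.Series, furosemide_eq_mg: pd.Series) -> pd.Series:
    """mL of urine per mg of IV furosemide equivalent; NaN when no IV dose was given."""
    dose = furosemide_eq_mg.where(furosemide_eq_mg > 0)  # zero/negative dose -> NaN
    return urine_ml / dose

def assign_risk_category(score: pd.Series) -> pd.Series:
    """Low (<=7), Moderate (8-12), High (>=13) for integer BAN-ADHF scores."""
    return pd.cut(score, bins=[-np.inf, 7, 12, np.inf],
                  labels=["Low", "Moderate", "High"], right=True)

# Hypothetical rows for illustration only (not MIMIC-IV data)
demo = pd.DataFrame({
    "urine_output_24h_ml": [2400.0, 1800.0, 900.0],
    "iv_diuretic_dose_24h_mg": [40.0, 0.0, 80.0],
    "ban_adhf_total_score": [5, 10, 15],
})
demo["diuretic_efficiency_24h"] = diuretic_efficiency(
    demo["urine_output_24h_ml"], demo["iv_diuretic_dose_24h_mg"])
demo["risk"] = assign_risk_category(demo["ban_adhf_total_score"])
print(demo)
```

Using `pd.cut` with `right=True` keeps the boundary scores (7 and 12) in the lower category, matching the ≤7 / 8-12 / ≥13 definition.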

**Key Results:**
- Among 1,019 patients with calculable 24-hour diuretic efficiency, the BAN-ADHF score demonstrated strong inverse correlation (Spearman ρ = -0.518, 95% CI: -0.560 to -0.473; p<0.001)
- Discrimination for the lowest efficiency quintile was good (AUROC 0.780, 95% CI: 0.743-0.812)
- Median efficiency decreased across risk categories: 47.4 mL/mg (low-risk), 29.0 mL/mg (moderate-risk), and 11.3 mL/mg (high-risk), a 4.2-fold difference (p<0.001)
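The Key Results report 95% CIs for Spearman's ρ and the AUROC, but the interval method is not shown in this section. One common option is a nonparametric percentile bootstrap that resamples patients with replacement. The sketch below applies it to Spearman's ρ on synthetic data; the function name, `n_boot`, and the seeds are illustrative assumptions, not the study's settings.

```python
import numpy as np
from scipy.stats import spearmanr

def spearman_bootstrap_ci(x, y, n_boot=1000, alpha=0.05, seed=0):
    """Spearman's rho with a percentile-bootstrap (1 - alpha) confidence interval."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    rho, _ = spearmanr(x, y)
    rng = np.random.default_rng(seed)
    n = len(x)
    boot = np.empty(n_boot)
    for i in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample patients with replacement
        boot[i], _ = spearmanr(x[idx], y[idx])
    lower, upper = np.nanpercentile(boot, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return rho, lower, upper

# Synthetic data only: higher score -> lower efficiency, with noise
gen = np.random.default_rng(42)
score = gen.integers(0, 27, size=300)
efficiency = np.exp(4.0 - 0.1 * score + gen.normal(0, 0.6, size=300))
rho, lo, hi = spearman_bootstrap_ci(score, efficiency)
print(f"Spearman rho = {rho:.3f} (95% CI {lo:.3f} to {hi:.3f})")
```

The same resampling loop can wrap `roc_auc_score` to obtain a percentile interval for the AUROC.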

---

### Prerequisites

Before running this notebook:

1. **MIMIC-IV Access**: Complete PhysioNet credentialing and obtain access to MIMIC-IV v3.1
2. **BigQuery Setup**: Link your Google Cloud project to PhysioNet BigQuery datasets
3. **SQL Cohort**: Run the SQL queries in the `/sql/` folder to create the `ban_adhf.final_cohort` table
4. **Project ID**: Replace `YOUR-PROJECT-ID` with your Google Cloud project ID

---

### Analysis Sections

| Section | Description |
|---------|-------------|
| 1 | Setup & Data Loading |
| 2 | Risk Category Derivation |
| 3 | Baseline Characteristics (Table 1) |
| 4 | 24-Hour Diuretic Efficiency (Co-Primary Outcome) |
| 5 | 72-Hour Diuretic Efficiency (Co-Primary Outcome) |
| 6 | Diuretic Resistance Analysis |
| 7 | In-Hospital Mortality (Exploratory) |
| 8 | Subgroup Analyses |
| 9 | Sensitivity Analyses |
| 10 | Secondary Outcomes |
| 11 | Output Generation |

In [None]:
#==========================================================================
# BAN-ADHF ICU VALIDATION STUDY - CLEAN ANALYSIS NOTEBOOK
# External Validation of BAN-ADHF Score in Critically Ill ADHF Patients
#==========================================================================
# Data Source: MIMIC-IV v3.1 (2008-2022)
# Reporting Standard: TRIPOD+AI 2024
# Date: December 2024
#==========================================================================
# SECTION 1: SETUP & DATA LOADING
# Package Installation and Core Imports
#==========================================================================
# PURPOSE: Install required packages and import core libraries
# OUTPUT: Confirmation of successful package installation
#==========================================================================

# 1.1 Install required packages (run once per session)
!pip install tableone dcurves lifelines --quiet

# 1.2 Import core libraries
import pandas as pd
import numpy as np
from scipy import stats
from scipy.stats import spearmanr, kruskal, chi2_contingency
from sklearn.metrics import roc_auc_score, roc_curve
from lifelines.utils import concordance_index
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# 1.3 Set display options for cleaner output
pd.set_option('display.max_columns', 50)
pd.set_option('display.float_format', '{:.3f}'.format)

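# 1.4 Confirm setup by importing each package (reports anything missing)
import importlib

print("="*60)
print("ENVIRONMENT SETUP CHECK")
print("="*60)
required_modules = [
    ('tableone',   'Table 1 generation'),
    ('dcurves',    'Decision curve analysis'),
    ('lifelines',  'C-index calculation'),
    ('scipy',      'Statistical tests'),
    ('sklearn',    'ROC/AUROC analysis'),
    ('pandas',     'Data manipulation'),
    ('numpy',      'Data manipulation'),
    ('matplotlib', 'Visualization'),
    ('seaborn',    'Visualization'),
]
for module_name, purpose in required_modules:
    try:
        importlib.import_module(module_name)
        print(f"✓ {module_name:<13} - {purpose}")
    except ImportError:
        print(f"✗ {module_name:<13} - NOT AVAILABLE ({purpose})")
print("="*60)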

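For reproducibility, it can help to record the exact version of each package next to the outputs, since Colab images change over time. A minimal sketch using the standard-library `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the list uses pip distribution names (e.g. `scikit-learn`, not `sklearn`).

```python
from importlib.metadata import version, PackageNotFoundError

def installed_versions(packages):
    """Map each distribution name to its installed version string, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

# Distribution names as used by pip
versions = installed_versions([
    "pandas", "numpy", "scipy", "scikit-learn", "lifelines",
    "tableone", "dcurves", "matplotlib", "seaborn",
])
for name, ver in versions.items():
    print(f"{name:<14} {ver if ver else 'NOT INSTALLED'}")
```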
In [None]:
#==========================================================================
# SECTION 1: SETUP & DATA LOADING
# Cell 2: BigQuery Authentication and Data Loading
#==========================================================================
# PURPOSE: Connect to BigQuery and load the final BAN-ADHF cohort
# OUTPUT: Loaded dataframe with cohort summary statistics
#==========================================================================

from google.colab import auth
from google.cloud import bigquery

# 2.1 Authenticate with Google Cloud
auth.authenticate_user()
print("✓ Google Cloud authentication successful")

# 2.2 Initialize BigQuery client
PROJECT_ID = 'YOUR-PROJECT-ID'
client = bigquery.Client(project=PROJECT_ID)
print(f"✓ BigQuery client initialized for project: {PROJECT_ID}")

# 2.3 Load the final cohort from BigQuery (uses PROJECT_ID set above)
query = f"""
SELECT *
FROM `{PROJECT_ID}.ban_adhf.final_cohort`
"""

df = client.query(query).to_dataframe()
print(f"✓ Data loaded: {len(df):,} rows, {len(df.columns)} columns")

# 2.4 Cohort summary for verification
print("\n" + "="*60)
print("COHORT SUMMARY")
print("="*60)
print(f"Total hospital admissions:   {len(df):,}")
print(f"Unique patients:             {df['subject_id'].nunique():,}")
print(f"In-hospital deaths:          {df['hospital_expire_flag'].sum():,} ({100*df['hospital_expire_flag'].mean():.1f}%)")
print(f"Median BAN-ADHF score:       {df['ban_adhf_total_score'].median():.0f}")
print(f"Mean BAN-ADHF score:         {df['ban_adhf_total_score'].mean():.1f}")
print(f"ICU stay ≥24h:               {df['icu_stay_ge_24h'].sum():,}")
print(f"ICU stay ≥72h:               {df['icu_stay_ge_72h'].sum():,}")
print("="*60)

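Cell 2 needs the project ID in the table reference as well as in the client. One way to avoid querying with the unreplaced placeholder is to build the SQL from a validated project ID. This sketch assumes a standard project ID (6-30 lowercase letters, digits, or hyphens, starting with a letter and not ending with a hyphen); `build_cohort_query` is a hypothetical helper, and the dataset and table names are the ones used in Cell 2.

```python
import re

# Standard GCP project-ID pattern (assumption: not a legacy domain-scoped ID)
_PROJECT_ID_RE = re.compile(r"^[a-z][a-z0-9-]{4,28}[a-z0-9]$")

def build_cohort_query(project_id, dataset="ban_adhf", table="final_cohort"):
    """Return the cohort SELECT with a fully qualified, backtick-quoted table ID."""
    if not _PROJECT_ID_RE.match(project_id):
        raise ValueError(f"Not a valid Google Cloud project ID: {project_id!r}")
    return f"SELECT *\nFROM `{project_id}.{dataset}.{table}`"

print(build_cohort_query("my-project-123"))
```

The returned string can be passed to `client.query(...).to_dataframe()` as in Cell 2; the placeholder `YOUR-PROJECT-ID` fails the check because it contains uppercase letters.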
In [None]:
#==========================================================================
# SECTION 1: SETUP & DATA LOADING
# Cell 3: Data Validation and Column Verification
#==========================================================================
# PURPOSE: Verify all required columns exist before analysis
# OUTPUT: Confirmation of data structure and variable availability
#==========================================================================

print("="*60)
print("DATA VALIDATION - COLUMN VERIFICATION")
print("="*60)

# 3.1 Check BAN-ADHF score components (8 variables)
print("\n1. BAN-ADHF SCORE COMPONENTS (points_* columns)")
print("-"*40)
ban_adhf_components = {
    'points_creatinine': 'Creatinine (0/2/4 pts)',
    'points_bun': 'BUN (0/2/3 pts)',
    'points_ntprobnp': 'NT-proBNP (0/2/4 pts)',
    'points_dbp': 'Diastolic BP (0/2 pts)',
    'points_home_diuretic': 'Home diuretic dose (0/3/6 pts)',
    'points_afib': 'Atrial fibrillation (0/2 pts)',
    'points_htn': 'Hypertension (0/2 pts)',
    'points_prior_hf': 'Prior HF hospitalization (0/3 pts)'
}

component_sum = 0
for col, description in ban_adhf_components.items():
    if col in df.columns:
        print(f"  ✓ {col}: {description}")
        component_sum += 1
    else:
        print(f"  ✗ {col} - MISSING")

print(f"  → {component_sum}/8 components present")

# 3.2 Verify total score calculation
print("\n2. BAN-ADHF TOTAL SCORE VERIFICATION")
print("-"*40)
print(f"  ban_adhf_total_score range: {df['ban_adhf_total_score'].min():.0f} - {df['ban_adhf_total_score'].max():.0f}")
print(f"  Expected range: 0-26")
print(f"  Mean (SD): {df['ban_adhf_total_score'].mean():.1f} ({df['ban_adhf_total_score'].std():.1f})")
print(f"  Median (IQR): {df['ban_adhf_total_score'].median():.0f} ({df['ban_adhf_total_score'].quantile(0.25):.0f}-{df['ban_adhf_total_score'].quantile(0.75):.0f})")

# Verify component sum equals total (spot check)
if all(col in df.columns for col in ban_adhf_components.keys()):
    component_cols = list(ban_adhf_components.keys())
    df['_check_sum'] = df[component_cols].sum(axis=1)
    match_pct = (df['_check_sum'] == df['ban_adhf_total_score']).mean() * 100
    print(f"  Component sum matches total: {match_pct:.1f}%")
    df.drop('_check_sum', axis=1, inplace=True)

# 3.3 Check outcome variables
print("\n3. OUTCOME VARIABLES")
print("-"*40)
outcome_vars = {
    'hospital_expire_flag': 'In-hospital mortality (exploratory)',
    'diuretic_efficiency_24h': '24h diuretic efficiency (co-primary)',
    'diuretic_efficiency_72h': '72h diuretic efficiency (co-primary)',
    'urine_output_24h_ml': '24h urine output (for binary DR)',
    'diuretic_resistance': 'Binary DR flag (pre-calculated)'
}

for col, description in outcome_vars.items():
    if col in df.columns:
        non_null = df[col].notna().sum()
        print(f"  ✓ {col}")
        print(f"      {description}: {non_null:,} non-null")
    else:
        print(f"  ✗ {col} - MISSING")

# 3.4 Check subgroup variables
print("\n4. SUBGROUP VARIABLES (8 pre-specified)")
print("-"*40)
subgroup_vars = {
    'age_65_or_older': 'Age ≥65 vs <65',
    'gender': 'Male vs Female',
    'hx_diabetes': 'Diabetes (yes/no)',
    'chronic_advanced_ckd': 'Advanced chronic kidney disease',
    'hx_atrial_fibrillation': 'Atrial fibrillation',
    'on_home_diuretics': 'On home diuretics',
    'cardiogenic_shock': 'Cardiogenic shock',
    'hf_phenotype': 'HF phenotype (HFrEF/HFpEF/HFmrEF)'
}

for col, description in subgroup_vars.items():
    if col in df.columns:
        print(f"  ✓ {col}: {description}")
    else:
        print(f"  ✗ {col} - MISSING")

# 3.5 Check ICU duration and diuretic dose variables
print("\n5. ICU DURATION & DIURETIC DOSE VARIABLES")
print("-"*40)
icu_vars = {
    'icu_stay_ge_24h': 'ICU stay ≥24h',
    'icu_stay_ge_72h': 'ICU stay ≥72h',
    'iv_diuretic_dose_24h_mg': '24h IV diuretic dose',
    'iv_diuretic_dose_72h_mg': '72h IV diuretic dose'
}

for col, description in icu_vars.items():
    if col in df.columns:
        # Count patients only after confirming the column exists (avoids a KeyError)
        if col.startswith('icu_stay_'):
            description = f"{description}: {int(df[col].sum()):,} patients"
        print(f"  ✓ {col}: {description}")
    else:
        print(f"  ✗ {col} - MISSING")

# 3.6 Summary and score distribution for cutoff derivation
print("\n" + "="*60)
print("VALIDATION COMPLETE")
print("="*60)
print("\nBAN-ADHF Score Distribution (for cutoff derivation):")
print("-"*40)
percentiles = [10, 20, 25, 33, 50, 67, 75, 80, 90]
for p in percentiles:
    val = df['ban_adhf_total_score'].quantile(p/100)
    print(f"  {p:3}th percentile: {val:.0f}")

print("\n" + "="*60)
print("NOTE: Risk category cutoffs are derived from scratch in")
print("Section 2 (Cell 4: score distribution; Cell 5: 5 statistical")
print("methods), ignoring any existing risk category column that")
print("may use different thresholds.")
print("="*60)

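Cell 3 reports the share of rows whose components match the total but does not show which rows fail. A sketch of a stricter check that raises on missing columns and returns the offending rows, assuming the eight `points_*` columns listed above and the 0-26 score range; `check_score_components` is a hypothetical helper.

```python
import pandas as pd

COMPONENT_COLS = [
    "points_creatinine", "points_bun", "points_ntprobnp", "points_dbp",
    "points_home_diuretic", "points_afib", "points_htn", "points_prior_hf",
]

def check_score_components(data, total_col="ban_adhf_total_score",
                           components=COMPONENT_COLS, max_score=26):
    """Raise on missing columns; return rows where components don't sum to the total
    or the total lies outside 0..max_score (missing values are flagged too)."""
    missing = [c for c in [total_col, *components] if c not in data.columns]
    if missing:
        raise KeyError(f"Missing required columns: {missing}")
    # min_count makes the sum NaN if any component is missing, so the row is flagged
    recomputed = data[components].sum(axis=1, min_count=len(components))
    bad = (recomputed != data[total_col]) | ~data[total_col].between(0, max_score)
    return data.loc[bad, [total_col, *components]].assign(recomputed=recomputed[bad])

# Hypothetical two-row example: row 1's components sum to 16, not 15
demo = pd.DataFrame({c: [0, 2] for c in COMPONENT_COLS})
demo["ban_adhf_total_score"] = [0, 15]
mismatches = check_score_components(demo)
print(mismatches)
```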
In [None]:
#==========================================================================
# SECTION 2: RISK CATEGORY DERIVATION
# Cell 4: BAN-ADHF Score Distribution Visualization
#==========================================================================
# PURPOSE: Visualize the score distribution before deriving cutoffs
# OUTPUT: Histogram with percentile markers
#==========================================================================

import matplotlib.pyplot as plt

print("="*60)
print(f"BAN-ADHF SCORE DISTRIBUTION (Full Cohort, N={len(df):,})")
print("="*60)

# Calculate distribution statistics
scores = df['ban_adhf_total_score']
print(f"\nDescriptive Statistics:")
print(f"  N:          {len(scores):,}")
print(f"  Mean:       {scores.mean():.1f}")
print(f"  SD:         {scores.std():.1f}")
print(f"  Median:     {scores.median():.0f}")
print(f"  IQR:        {scores.quantile(0.25):.0f} - {scores.quantile(0.75):.0f}")
print(f"  Range:      {scores.min():.0f} - {scores.max():.0f}")

# Calculate percentiles using pandas quantile method
print(f"\nPercentiles (calculated from N={len(scores):,} using pandas .quantile()):")
percentile_values = {}
for p in [25, 33, 50, 67, 75]:
    val = scores.quantile(p/100)
    percentile_values[p] = val
    print(f"  {p}th percentile: {val:.0f}")

# Create histogram
fig, ax = plt.subplots(figsize=(10, 6))

# Plot histogram
counts, bins, patches = ax.hist(scores, bins=range(0, 28),
                                 edgecolor='black', alpha=0.7, color='steelblue')

# Add vertical lines at the 25th and 67th percentiles computed above
p25, p67 = percentile_values[25], percentile_values[67]
ax.axvline(x=p25, color='green', linestyle='--', linewidth=2, label=f'25th %ile (score={p25:.0f})')
ax.axvline(x=p67, color='red', linestyle='--', linewidth=2, label=f'67th %ile (score={p67:.0f})')

# Add labels
ax.set_xlabel('BAN-ADHF Total Score', fontsize=12)
ax.set_ylabel('Number of Patients', fontsize=12)
ax.set_title(f'Distribution of BAN-ADHF Scores in ICU ADHF Cohort (N={len(scores):,})', fontsize=14)
ax.legend(loc='upper right')
ax.set_xlim(0, 27)

# Add text annotation
textstr = f'Mean: {scores.mean():.1f}\nMedian: {scores.median():.0f}\nSD: {scores.std():.1f}'
ax.text(0.95, 0.95, textstr, transform=ax.transAxes, fontsize=10,
        verticalalignment='top', horizontalalignment='right',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.show()

# Show frequency table
print("\n" + "="*60)
print("SCORE FREQUENCY TABLE")
print("="*60)
score_counts = scores.value_counts().sort_index()
print(f"\n{'Score':<8} {'N':<8} {'%':<8} {'Cumulative %':<12}")
print("-"*36)
cumsum = 0
for score in range(int(scores.min()), int(scores.max()) + 1):
    n = score_counts.get(score, 0)
    pct = 100 * n / len(scores)
    cumsum += pct
    print(f"{score:<8} {n:<8} {pct:<8.1f} {cumsum:<12.1f}")

print("\n→ Next: Apply 5 statistical methods to derive optimal risk cutoffs")

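The frequency table printed in Cell 4 can also be built as a DataFrame, which is easier to export. This sketch fills unobserved scores with zero counts via `reindex`; unlike the loop above, it drops missing scores before computing percentages (an assumption that matters only if the score column has gaps). The helper name is hypothetical.

```python
import pandas as pd

def score_frequency_table(scores):
    """Count, percent, and cumulative percent for every integer score in the observed range."""
    scores = scores.dropna().astype(int)
    full_range = range(scores.min(), scores.max() + 1)
    counts = scores.value_counts().reindex(full_range, fill_value=0)  # unobserved -> 0
    table = pd.DataFrame({"N": counts, "pct": 100 * counts / counts.sum()})
    table["cumulative_pct"] = table["pct"].cumsum()
    table.index.name = "score"
    return table

freq = score_frequency_table(pd.Series([3, 3, 5, 6, 6, 6]))
print(freq)
```

In the notebook, `score_frequency_table(df['ban_adhf_total_score'])` would produce the same rows as the printed table.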
In [None]:
#==========================================================================
# SECTION 2: RISK CATEGORY DERIVATION
# Cell 5: Optimal Cutoff Determination Using 5 Statistical Methods
#==========================================================================
# PURPOSE: Derive risk category cutoffs from our data using multiple methods
# BACKGROUND: Original Segar 2024 derivation did NOT define categorical
#             cutoffs - the score was presented as continuous (0-26).
#             We derive cutoffs prioritizing diuretic efficiency (the
#             score's intended outcome) over mortality.
# OUTPUT: Optimal cutoffs from each method for comparison
#==========================================================================

from sklearn.metrics import roc_curve
from sklearn.tree import DecisionTreeClassifier
from scipy.stats import chi2_contingency, kruskal
import numpy as np

print("="*70)
print("RISK CATEGORY CUTOFF DERIVATION")
print("="*70)
print("\nRationale: The original BAN-ADHF derivation (Segar 2024) did not")
print("define categorical risk cutoffs. We derive optimal cutoffs using")
print("5 complementary statistical methods, prioritizing diuretic efficiency")
print("(the score's intended purpose) over mortality.\n")

# Prepare the 24h efficiency cohort for analysis
# Criteria: ICU ≥24h AND non-missing, positive 24h efficiency
# (efficiency is undefined without an IV dose > 0; zero urine output is also excluded)
df_eff = df[(df['icu_stay_ge_24h'] == 1) &
            (df['diuretic_efficiency_24h'].notna()) &
            (df['diuretic_efficiency_24h'] > 0)].copy()

print(f"Analysis cohort: N = {len(df_eff):,} (ICU ≥24h with valid 24h efficiency)")

# Define lowest quintile (bottom 20%) as binary outcome
quintile_threshold = df_eff['diuretic_efficiency_24h'].quantile(0.20)
df_eff['lowest_quintile'] = (df_eff['diuretic_efficiency_24h'] <= quintile_threshold).astype(int)

print(f"Lowest quintile threshold: ≤{quintile_threshold:.1f} mL/mg")
print(f"Patients in lowest quintile: {df_eff['lowest_quintile'].sum()} ({100*df_eff['lowest_quintile'].mean():.1f}%)")

#------------------------------------------------------------------------------
# METHOD 1: YOUDEN'S INDEX (Single Cutoff Optimization)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("METHOD 1: YOUDEN'S INDEX")
print("="*70)
print("Purpose: Find single cutoff that maximizes (Sensitivity + Specificity - 1)")

# For lowest quintile of efficiency (binary outcome)
fpr, tpr, thresholds = roc_curve(df_eff['lowest_quintile'],
                                  df_eff['ban_adhf_total_score'])
youden_j = tpr - fpr
optimal_idx = np.argmax(youden_j)
youden_threshold = thresholds[optimal_idx]
youden_sensitivity = tpr[optimal_idx]
youden_specificity = 1 - fpr[optimal_idx]

print(f"\nOutcome: Lowest quintile of 24h diuretic efficiency")
print(f"Optimal cutoff: ≥{youden_threshold:.0f}")
print(f"Sensitivity: {youden_sensitivity:.3f}")
print(f"Specificity: {youden_specificity:.3f}")
print(f"Youden's J: {youden_j[optimal_idx]:.3f}")

# Also check for mortality
fpr_m, tpr_m, thresholds_m = roc_curve(df['hospital_expire_flag'],
                                        df['ban_adhf_total_score'])
youden_j_m = tpr_m - fpr_m
optimal_idx_m = np.argmax(youden_j_m)
youden_threshold_m = thresholds_m[optimal_idx_m]

print(f"\n[Secondary] Mortality optimal cutoff: ≥{youden_threshold_m:.0f}")

#------------------------------------------------------------------------------
# METHOD 2: RECURSIVE PARTITIONING (CART Decision Tree)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("METHOD 2: RECURSIVE PARTITIONING (CART Decision Tree)")
print("="*70)
print("Purpose: Data-driven identification of natural split points")

X = df_eff[['ban_adhf_total_score']].values
y = df_eff['lowest_quintile'].values

tree = DecisionTreeClassifier(max_depth=2, min_samples_leaf=50, random_state=42)
tree.fit(X, y)

# Extract thresholds from tree structure
def get_tree_thresholds(tree_model):
    tree_ = tree_model.tree_
    thresholds = []
    def recurse(node):
        if tree_.feature[node] != -2:  # Not a leaf
            thresholds.append(tree_.threshold[node])
            recurse(tree_.children_left[node])
            recurse(tree_.children_right[node])
    recurse(0)
    return sorted(thresholds)

tree_thresholds = get_tree_thresholds(tree)
print(f"\nOutcome: Lowest quintile of 24h diuretic efficiency")
print(f"Decision tree splits at: {[f'{t:.1f}' for t in tree_thresholds]}")

#------------------------------------------------------------------------------
# METHOD 3: MAXIMUM CHI-SQUARE (Two-Cutoff Optimization)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("METHOD 3: MAXIMUM CHI-SQUARE (Two-Cutoff Optimization)")
print("="*70)
print("Purpose: Find two cutoffs that maximize chi-square for binary outcome")

def find_optimal_cutoffs_chi2(data, score_col, outcome_col, score_range=(4, 20)):
    """Find two cutoffs that maximize chi-square statistic"""
    best_chi2 = 0
    best_cutoffs = (7, 13)
    best_p = 1.0

    for low_cut in range(score_range[0], score_range[1] - 2):
        for high_cut in range(low_cut + 2, score_range[1]):
            cats = pd.cut(data[score_col],
                         bins=[-np.inf, low_cut, high_cut, np.inf],
                         labels=['Low', 'Moderate', 'High'])
            contingency = pd.crosstab(cats, data[outcome_col])
            if contingency.shape == (3, 2):
                chi2, p, _, _ = chi2_contingency(contingency)
                if chi2 > best_chi2:
                    best_chi2 = chi2
                    best_cutoffs = (low_cut, high_cut)
                    best_p = p

    return best_cutoffs, best_chi2, best_p

# For lowest quintile
chi2_cutoffs, chi2_stat, chi2_p = find_optimal_cutoffs_chi2(
    df_eff, 'ban_adhf_total_score', 'lowest_quintile'
)
print(f"\nOutcome: Lowest quintile of 24h diuretic efficiency")
print(f"Optimal cutoffs: ≤{chi2_cutoffs[0]} / {chi2_cutoffs[0]+1}-{chi2_cutoffs[1]} / ≥{chi2_cutoffs[1]+1}")
print(f"Chi-square: {chi2_stat:.2f}")
print(f"P-value: {chi2_p:.2e}")

# For mortality (full cohort)
chi2_cutoffs_m, chi2_stat_m, chi2_p_m = find_optimal_cutoffs_chi2(
    df, 'ban_adhf_total_score', 'hospital_expire_flag'
)
print(f"\n[Secondary] Mortality optimal cutoffs: ≤{chi2_cutoffs_m[0]} / ≥{chi2_cutoffs_m[1]+1}")
print(f"Chi-square: {chi2_stat_m:.2f}")

#------------------------------------------------------------------------------
# METHOD 4: MAXIMUM KRUSKAL-WALLIS H (Continuous Outcome)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("METHOD 4: MAXIMUM KRUSKAL-WALLIS H (Continuous Outcome)")
print("="*70)
print("Purpose: Find two cutoffs that maximize group separation for")
print("         continuous diuretic efficiency")

def find_optimal_cutoffs_kw(data, score_col, outcome_col, score_range=(4, 20)):
    """Find two cutoffs that maximize Kruskal-Wallis H for continuous outcome"""
    best_h = 0
    best_cutoffs = (7, 13)
    best_p = 1.0

    for low_cut in range(score_range[0], score_range[1] - 2):
        for high_cut in range(low_cut + 2, score_range[1]):
            low = data[data[score_col] <= low_cut][outcome_col].values
            mod = data[(data[score_col] > low_cut) & (data[score_col] <= high_cut)][outcome_col].values
            high = data[data[score_col] > high_cut][outcome_col].values

            if len(low) > 10 and len(mod) > 10 and len(high) > 10:
                h_stat, p = kruskal(low, mod, high)
                if h_stat > best_h:
                    best_h = h_stat
                    best_cutoffs = (low_cut, high_cut)
                    best_p = p

    return best_cutoffs, best_h, best_p

kw_cutoffs, kw_h, kw_p = find_optimal_cutoffs_kw(
    df_eff, 'ban_adhf_total_score', 'diuretic_efficiency_24h'
)
print(f"\nOutcome: 24h diuretic efficiency (continuous)")
print(f"Optimal cutoffs: ≤{kw_cutoffs[0]} / {kw_cutoffs[0]+1}-{kw_cutoffs[1]} / ≥{kw_cutoffs[1]+1}")
print(f"Kruskal-Wallis H: {kw_h:.2f}")
print(f"P-value: {kw_p:.2e}")

#------------------------------------------------------------------------------
# METHOD 5: DISTRIBUTION-BASED (Tertiles/Quartiles)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("METHOD 5: DISTRIBUTION-BASED (Tertiles/Quartiles)")
print("="*70)
print("Purpose: Define cutoffs based on cohort score distribution")

# Tertiles (33rd and 67th percentiles)
tertile_low = df['ban_adhf_total_score'].quantile(0.333)
tertile_high = df['ban_adhf_total_score'].quantile(0.667)

# Quartiles (25th and 75th percentiles)
q1 = df['ban_adhf_total_score'].quantile(0.25)
q3 = df['ban_adhf_total_score'].quantile(0.75)

print(f"\nFull cohort (N={len(df):,}):")
print(f"  Tertiles (33rd/67th %ile): {tertile_low:.0f} / {tertile_high:.0f}")
print(f"    → Categories: ≤{tertile_low:.0f}, {tertile_low + 1:.0f}-{tertile_high - 1:.0f}, ≥{tertile_high:.0f}")
print(f"  Quartiles (25th/75th %ile): {q1:.0f} / {q3:.0f}")
print(f"    → Q1 boundary (low risk): ≤{q1:.0f}")
print(f"    → Q4 boundary (high risk): ≥{q3:.0f}")

#------------------------------------------------------------------------------
# SUMMARY TABLE
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SUMMARY: OPTIMAL CUTOFFS BY METHOD")
print("="*70)

# Format tree thresholds for display
tree_low = f"{tree_thresholds[0]:.1f}" if len(tree_thresholds) > 0 else "—"
tree_high = f"{tree_thresholds[1]:.1f}" if len(tree_thresholds) > 1 else "—"

print(f"\n{'Method':<35} {'Low/Mod':<12} {'Mod/High':<12} {'Primary Outcome':<20}")
print("-"*70)
print(f"{'1. Youden J (efficiency)':<35} {'—':<12} {'≥' + str(int(youden_threshold)):<12} {'Efficiency quintile':<20}")
print(f"{'2. Decision Tree (efficiency)':<35} {tree_low:<12} {tree_high:<12} {'Efficiency quintile':<20}")
print(f"{'3. Max Chi-square (efficiency)':<35} {'≤' + str(chi2_cutoffs[0]):<12} {'≥' + str(chi2_cutoffs[1]+1):<12} {'Efficiency quintile':<20}")
print(f"{'4. Max Kruskal-Wallis':<35} {'≤' + str(kw_cutoffs[0]):<12} {'≥' + str(kw_cutoffs[1]+1):<12} {'Continuous efficiency':<20}")
print(f"{'5a. Tertiles':<35} {'≤' + format(tertile_low, '.0f'):<12} {'≥' + format(tertile_high, '.0f'):<12} {'Distribution':<20}")
print(f"{'5b. Quartiles (Q1/Q4)':<35} {'≤' + format(q1, '.0f'):<12} {'≥' + format(q3, '.0f'):<12} {'Distribution':<20}")
print("-"*70)
print(f"{'[Literature] Alfonso 2025':<35} {'<7':<12} {'≥13':<12} {'Conceptual framework':<20}")
print(f"{'[Literature] Pandey 2025':<35} {'—':<12} {'>12 (≥13)':<12} {'Youden binary':<20}")
print(f"{'[Literature] Mauch 2025':<35} {'Q1: 0-6':<12} {'Q4: ≥13':<12} {'Quartiles':<20}")

print("\n" + "="*70)
print("KEY OBSERVATION:")
print("="*70)
print("For diuretic efficiency (primary outcome), optimal cutoffs converge on:")
print(f"  • Low/Moderate boundary: ≤7 (supported by Kruskal-Wallis, Quartiles)")
print(f"  • Moderate/High boundary: ≥13-14 (supported by Youden, Tree, Literature)")
print("\n→ Next: Head-to-head comparison of candidate frameworks (≥13 vs ≥14)")

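scikit-learn decision trees send samples with `feature <= threshold` to the left child. With integer scores, a split at 13.5 therefore keeps score 13 in the lower group and makes 14 the first score in the upper group; likewise 5.5 separates ≤5 from ≥6. The sketch below turns CART thresholds into explicit categories with `pd.cut(..., right=True)`, which uses the same ≤ convention. Both helper names are hypothetical.

```python
import numpy as np
import pandas as pd

def categorize_by_thresholds(scores, thresholds, labels=("Low", "Moderate", "High")):
    """Bin scores with the CART convention: score <= t falls in the lower bin."""
    edges = [-np.inf, *sorted(thresholds), np.inf]
    return pd.cut(scores, bins=edges, labels=list(labels), right=True)

def first_integer_above(threshold):
    """Smallest integer score placed in the upper group by a split at `threshold`."""
    return int(np.floor(threshold)) + 1

scores = pd.Series([5, 6, 13, 14])
print(categorize_by_thresholds(scores, [5.5, 13.5]).tolist())
print([first_integer_above(t) for t in (5.5, 13.5, 13.0)])
```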
In [None]:
#==========================================================================
# SECTION 2: RISK CATEGORY DERIVATION
# Cell 6: Head-to-Head Framework Comparison and Final Selection
#==========================================================================
# PURPOSE: Compare candidate frameworks using Cell 5 statistical results
# NOTE: BAN-ADHF was designed for diuretic efficiency, NOT mortality
# APPROACH: Cutoffs derived from efficiency outcomes; mortality is exploratory
#==========================================================================

from scipy.stats import kruskal, chi2_contingency

print("="*70)
print("FRAMEWORK COMPARISON AND FINAL CUTOFF SELECTION")
print("="*70)
print("\nBAN-ADHF was designed to predict DIURETIC EFFICIENCY.")
print("Cutoffs are derived from efficiency outcomes; mortality is exploratory.")

#------------------------------------------------------------------------------
# RECAP: CELL 5 STATISTICAL RESULTS
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("SUMMARY OF CELL 5 STATISTICAL METHODS (methods 1-4: N=1,019 efficiency cohort; method 5: N=1,505 full cohort)")
print("-"*70)
print("""
Method                          | Low/Mod Cutoff | Mod/High Cutoff | Primary Outcome
--------------------------------|----------------|-----------------|------------------
1. Youden's Index               | —              | ≥13             | Efficiency quintile
2. Decision Tree (CART)         | 5.5 (≤5)       | 13.5 (≥14)      | Efficiency quintile
3. Max Chi-square               | ≤12            | ≥17             | Efficiency quintile
4. Max Kruskal-Wallis (H=260.0) | ≤7             | ≥14             | Continuous efficiency
5a. Tertiles (33rd/67th %ile)   | ≤8             | ≥13             | Distribution
5b. Quartiles (25th/75th %ile)  | ≤7             | ≥15             | Distribution
""")

print("KEY CONVERGENCE POINTS:")
print("  • Low/Moderate boundary: ≤7 supported by Kruskal-Wallis + Quartiles")
print("  • Moderate/High boundary: ≥13 supported by Youden + Tertiles")
print("  • Alternative high threshold: ≥14 from Kruskal-Wallis and the CART split at 13.5")

#------------------------------------------------------------------------------
# LITERATURE REVIEW: PUBLISHED RISK CUTOFFS
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("PUBLISHED RISK CUTOFFS (for comparison)")
print("-"*70)
print("""
Source                  | Low Risk  | Moderate  | High Risk | Notes
------------------------|-----------|-----------|-----------|------------------
Segar 2024 (derivation) | —         | —         | —         | Continuous only (no cutoffs)
Alfonso 2025            | <7 (0-6)  | 7-12      | ≥13       | Only published 3-tier framework
Pandey 2025             | —         | —         | >12 (≥13) | Binary cutoff via Youden
Mauch 2025              | Q1: 0-6   | Q2-Q3     | Q4: ≥13   | Quartile-based
""")

#------------------------------------------------------------------------------
# CANDIDATE FRAMEWORKS
#------------------------------------------------------------------------------
print("-"*70)
print("CANDIDATE FRAMEWORKS FOR HEAD-TO-HEAD COMPARISON")
print("-"*70)
print("""
Framework A: ≤7 (Low), 8-12 (Moderate), ≥13 (High)
  • Low boundary: 25th percentile of cohort (Q1 = 7)
  • High boundary: Aligns with Youden, Tertiles, Alfonso, Pandey, Mauch
  • Deviation from Alfonso: score=7 classified as "Low" instead of "Moderate"

Framework B: ≤7 (Low), 8-13 (Moderate), ≥14 (High)
  • Kruskal-Wallis optimal (H=260.0); matches the CART split at 13.5 (score 13 falls in the lower group)
  • Does NOT align with any published framework

Framework C: ≤6 (Low), 7-12 (Moderate), ≥13 (High)
  • Exact match to Alfonso 2025 (<7 means scores 0-6)
""")

#------------------------------------------------------------------------------
# APPLY ALL THREE FRAMEWORKS
#------------------------------------------------------------------------------

def assign_framework_a(score):
    """Derived: ≤7, 8-12, ≥13"""
    if score <= 7:
        return 'Low'
    elif score <= 12:
        return 'Moderate'
    else:
        return 'High'

def assign_framework_b(score):
    """KW-optimal: ≤7, 8-13, ≥14"""
    if score <= 7:
        return 'Low'
    elif score <= 13:
        return 'Moderate'
    else:
        return 'High'

def assign_framework_c(score):
    """Alfonso exact: <7 (≤6), 7-12, ≥13"""
    if score < 7:
        return 'Low'
    elif score <= 12:
        return 'Moderate'
    else:
        return 'High'

# Apply to efficiency cohort
df_eff['risk_A'] = df_eff['ban_adhf_total_score'].apply(assign_framework_a)
df_eff['risk_B'] = df_eff['ban_adhf_total_score'].apply(assign_framework_b)
df_eff['risk_C'] = df_eff['ban_adhf_total_score'].apply(assign_framework_c)

# Apply to full cohort
df['risk_A'] = df['ban_adhf_total_score'].apply(assign_framework_a)
df['risk_B'] = df['ban_adhf_total_score'].apply(assign_framework_b)
df['risk_C'] = df['ban_adhf_total_score'].apply(assign_framework_c)

#------------------------------------------------------------------------------
# PRIMARY COMPARISON: 24-HOUR DIURETIC EFFICIENCY
#------------------------------------------------------------------------------
print("="*70)
print(f"PRIMARY OUTCOME: 24-HOUR DIURETIC EFFICIENCY (N={len(df_eff):,})")
print("(BAN-ADHF's intended purpose)")
print("="*70)

results_24h = {}
for name, col in [('A: ≤7/8-12/≥13', 'risk_A'),
                   ('B: ≤7/8-13/≥14', 'risk_B'),
                   ('C: Alfonso (<7/7-12/≥13)', 'risk_C')]:
    groups = [df_eff[df_eff[col] == cat]['diuretic_efficiency_24h'].values
              for cat in ['Low', 'Moderate', 'High']]
    kw_h, kw_p = kruskal(*groups)
    medians = {cat: df_eff[df_eff[col] == cat]['diuretic_efficiency_24h'].median()
               for cat in ['Low', 'Moderate', 'High']}
    results_24h[name] = {'kw_h': kw_h, 'medians': medians}

print(f"\n{'Framework':<30} {'Low':<10} {'Mod':<10} {'High':<10} {'KW H':<10}")
print("-"*70)
for name, res in results_24h.items():
    print(f"{name:<30} {res['medians']['Low']:<10.1f} {res['medians']['Moderate']:<10.1f} {res['medians']['High']:<10.1f} {res['kw_h']:<10.1f}")

winner_24h = max(results_24h.keys(), key=lambda x: results_24h[x]['kw_h'])
print(f"\n→ Best separation for 24h efficiency: {winner_24h}")

#------------------------------------------------------------------------------
# SECONDARY CHECK: 72-HOUR DIURETIC EFFICIENCY
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("SECONDARY CHECK: 72-HOUR DIURETIC EFFICIENCY")
print("-"*70)

# Prepare 72h cohort
df_eff_72h = df[(df['icu_stay_ge_72h'] == 1) &
                (df['diuretic_efficiency_72h'].notna()) &
                (df['diuretic_efficiency_72h'] > 0)].copy()

df_eff_72h['risk_A'] = df_eff_72h['ban_adhf_total_score'].apply(assign_framework_a)
df_eff_72h['risk_B'] = df_eff_72h['ban_adhf_total_score'].apply(assign_framework_b)
df_eff_72h['risk_C'] = df_eff_72h['ban_adhf_total_score'].apply(assign_framework_c)

print(f"72h efficiency cohort: N={len(df_eff_72h):,}")

results_72h = {}
for name, col in [('A: ≤7/8-12/≥13', 'risk_A'),
                   ('B: ≤7/8-13/≥14', 'risk_B'),
                   ('C: Alfonso (<7/7-12/≥13)', 'risk_C')]:
    groups = [df_eff_72h[df_eff_72h[col] == cat]['diuretic_efficiency_72h'].values
              for cat in ['Low', 'Moderate', 'High']]
    kw_h, kw_p = kruskal(*groups)
    results_72h[name] = {'kw_h': kw_h}

print(f"\n{'Framework':<30} {'Kruskal-Wallis H':<20}")
print("-"*50)
for name, res in results_72h.items():
    print(f"{name:<30} {res['kw_h']:<20.1f}")

winner_72h = max(results_72h.keys(), key=lambda x: results_72h[x]['kw_h'])
print(f"\n→ Best separation for 72h efficiency: {winner_72h}")

#------------------------------------------------------------------------------
# EXPLORATORY: IN-HOSPITAL MORTALITY
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("EXPLORATORY: IN-HOSPITAL MORTALITY (N=1,505)")
print("(BAN-ADHF was NOT designed for mortality prediction)")
print("-"*70)

mort_results = {}
for name, col in [('A: ≤7/8-12/≥13', 'risk_A'),
                   ('B: ≤7/8-13/≥14', 'risk_B'),
                   ('C: Alfonso (<7/7-12/≥13)', 'risk_C')]:
    contingency = pd.crosstab(df[col], df['hospital_expire_flag'])
    chi2, p, _, _ = chi2_contingency(contingency)
    mort_rates = {cat: df[df[col] == cat]['hospital_expire_flag'].mean() * 100
                  for cat in ['Low', 'Moderate', 'High']}
    mort_results[name] = {'chi2': chi2, 'rates': mort_rates}

print(f"\n{'Framework':<30} {'Low %':<10} {'Mod %':<10} {'High %':<10} {'χ²':<10}")
print("-"*70)
for name, res in mort_results.items():
    print(f"{name:<30} {res['rates']['Low']:<10.1f} {res['rates']['Moderate']:<10.1f} {res['rates']['High']:<10.1f} {res['chi2']:<10.2f}")

print("\n→ Mortality: All frameworks similar (exploratory outcome only)")

#------------------------------------------------------------------------------
# DISTRIBUTION COMPARISON
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("DISTRIBUTION BY FRAMEWORK (Full Cohort, N=1,505)")
print("-"*70)

print(f"\n{'Framework':<30} {'Low':<15} {'Moderate':<15} {'High':<15}")
print("-"*75)
for name, col in [('A: ≤7/8-12/≥13', 'risk_A'),
                   ('B: ≤7/8-13/≥14', 'risk_B'),
                   ('C: Alfonso (<7/7-12/≥13)', 'risk_C')]:
    low_n = (df[col] == 'Low').sum()
    mod_n = (df[col] == 'Moderate').sum()
    high_n = (df[col] == 'High').sum()
    print(f"{name:<30} {low_n:>4} ({100*low_n/len(df):>5.1f}%)    {mod_n:>4} ({100*mod_n/len(df):>5.1f}%)    {high_n:>4} ({100*high_n/len(df):>5.1f}%)")

n_score_7 = (df['ban_adhf_total_score'] == 7).sum()
n_score_13 = (df['ban_adhf_total_score'] == 13).sum()
print(f"\nKey differences:")
print(f"  Score=7 (N={n_score_7}):  Low in A,B | Moderate in C (Alfonso)")
print(f"  Score=13 (N={n_score_13}): High in A,C | Moderate in B")

#------------------------------------------------------------------------------
# SCORE=7 ANALYSIS: Should it be Low or Moderate?
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("SCORE=7 ANALYSIS: Low (A,B) vs Moderate (C)?")
print("-"*70)

print("\n24h Efficiency by Adjacent Scores:")
print(f"{'Score':<8} {'N':<8} {'Median (mL/mg)':<15}")
print("-"*35)
for score in [5, 6, 7, 8, 9]:
    subset = df_eff[df_eff['ban_adhf_total_score'] == score]['diuretic_efficiency_24h']
    if len(subset) > 0:
        print(f"{score:<8} {len(subset):<8} {subset.median():<15.1f}")

# Compare to determine grouping
eff_6 = df_eff[df_eff['ban_adhf_total_score'] == 6]['diuretic_efficiency_24h'].median()
eff_7 = df_eff[df_eff['ban_adhf_total_score'] == 7]['diuretic_efficiency_24h'].median()
eff_8 = df_eff[df_eff['ban_adhf_total_score'] == 8]['diuretic_efficiency_24h'].median()

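# Report whichever neighboring score the score=7 median is closer to
if abs(eff_7 - eff_6) < abs(eff_7 - eff_8):
    print(f"\n→ Score=7 efficiency ({eff_7:.1f}) is closer to score=6 ({eff_6:.1f}) than to score=8 ({eff_8:.1f})")
    print("→ Supports classifying score=7 as LOW (Frameworks A,B)")
else:
    print(f"\n→ Score=7 efficiency ({eff_7:.1f}) is closer to score=8 ({eff_8:.1f}) than to score=6 ({eff_6:.1f})")
    print("→ Supports classifying score=7 as MODERATE (Framework C)")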

#------------------------------------------------------------------------------
# FINAL SELECTION AND JUSTIFICATION
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("FINAL CUTOFF SELECTION")
print("="*70)

# Get values for justification
kw_a = results_24h['A: ≤7/8-12/≥13']['kw_h']
kw_b = results_24h['B: ≤7/8-13/≥14']['kw_h']
kw_c = results_24h['C: Alfonso (<7/7-12/≥13)']['kw_h']

print(f"""
SELECTED: Framework A — Low (≤7), Moderate (8-12), High (≥13)

JUSTIFICATION:

1. HIGH-RISK THRESHOLD (≥13):
   • Cell 5 Youden's Index: ≥13 optimal for efficiency quintile
   • Cell 5 Decision Tree: split at 13.5
   • Cell 5 Tertiles: 67th percentile = 13
   • Literature: Alfonso (≥13), Pandey (>12), Mauch Q4 (≥13)
   → Strong convergence across statistical methods and literature

2. LOW-RISK THRESHOLD (≤7):
   • Cell 5 Kruskal-Wallis: ≤7 optimal for continuous efficiency
   • Cell 5 Quartiles: 25th percentile = 7
   • Score=7 efficiency profile closer to score=6 than score=8
   → Data-driven threshold; minor deviation from Alfonso (<7)

3. WHY NOT FRAMEWORK B (≥14)?
   • Higher Kruskal-Wallis H ({kw_b:.1f} vs {kw_a:.1f})
   • BUT: No published framework uses ≥14
   • Marginal statistical gain vs. literature alignment trade-off
   → Literature concordance prioritized for comparability with published studies

4. WHY NOT EXACT ALFONSO (Framework C)?
   • Lower Kruskal-Wallis H ({kw_c:.1f} vs {kw_a:.1f})
   • Score=7 efficiency profile supports "Low" classification
   → Minor adaptation justified by the observed efficiency data in this cohort

5. TRANSPARENCY:
   • Alfonso uses <7 (score 7 = moderate)
   • We use ≤7 (score 7 = low) based on the score distribution (25th percentile = 7)
   • Difference affects {n_score_7} patients ({100*n_score_7/len(df):.1f}% of cohort)
""")

#------------------------------------------------------------------------------
# APPLY FINAL RISK CATEGORIES
#------------------------------------------------------------------------------
print("="*70)
print("APPLYING FINAL RISK CATEGORIES")
print("="*70)

df['risk_category'] = df['ban_adhf_total_score'].apply(assign_framework_a)
df['risk_category'] = pd.Categorical(
    df['risk_category'],
    categories=['Low', 'Moderate', 'High'],
    ordered=True
)

print("\nFinal Distribution (N=1,505):")
print(f"{'Category':<12} {'Score Range':<15} {'N':<10} {'%':<10}")
print("-"*47)
for cat, score_range in [('Low', '≤7'), ('Moderate', '8-12'), ('High', '≥13')]:
    n = (df['risk_category'] == cat).sum()
    pct = 100 * n / len(df)
    print(f"{cat:<12} {score_range:<15} {n:<10} {pct:<10.1f}")

# Clean up
df.drop(['risk_A', 'risk_B', 'risk_C'], axis=1, inplace=True)
df_eff.drop(['risk_A', 'risk_B', 'risk_C'], axis=1, inplace=True)

print("\n✓ Risk categories applied to dataframe as 'risk_category'")
print("\n→ Next: Define analysis sub-cohorts")
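The cutoff comparison in this cell reduces to two steps: map each integer score to Low/Moderate/High under a pair of cutoffs, then compare the Kruskal-Wallis H of efficiency across the three groups. A minimal sketch, assuming a hypothetical helper `assign_by_cutoffs` (the notebook's own `assign_framework_a/b/c` are defined in an earlier cell and not reproduced here) and simulated scores and efficiencies rather than MIMIC-IV data:

```python
import numpy as np
from scipy.stats import kruskal

def assign_by_cutoffs(score, low_max, high_min):
    """Hypothetical helper: Low if score <= low_max, High if score >= high_min, else Moderate."""
    if score <= low_max:
        return 'Low'
    if score >= high_min:
        return 'High'
    return 'Moderate'

# Candidate frameworks as (low_max, high_min); for integer scores, Alfonso's "<7" equals "≤6"
FRAMEWORKS = {
    'A: ≤7/8-12/≥13': (7, 13),
    'B: ≤7/8-13/≥14': (7, 14),
    'C: Alfonso (<7/7-12/≥13)': (6, 13),
}

def kruskal_h_by_framework(scores, efficiency):
    """Kruskal-Wallis H of efficiency across Low/Moderate/High for each framework."""
    scores = np.asarray(scores)
    efficiency = np.asarray(efficiency, dtype=float)
    results = {}
    for name, (low_max, high_min) in FRAMEWORKS.items():
        cats = np.array([assign_by_cutoffs(s, low_max, high_min) for s in scores])
        groups = [efficiency[cats == c] for c in ('Low', 'Moderate', 'High')]
        h, _ = kruskal(*groups)
        results[name] = h
    return results

# Simulated example (not study data): scores 0-20 chosen only for illustration,
# efficiency falling as the score rises
rng = np.random.default_rng(0)
sim_scores = rng.integers(0, 21, size=500)
sim_eff = np.exp(4.0 - 0.12 * sim_scores + rng.normal(0, 0.6, size=500))

h_values = kruskal_h_by_framework(sim_scores, sim_eff)
for name, h in h_values.items():
    print(f"{name:<30} H = {h:.1f}")
```

A larger H means the framework's three categories separate the efficiency distribution more strongly; as the justification above notes, H alone was not the only selection criterion.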

#==========================================================================
# SECTION 2: RISK CATEGORY DERIVATION
# Cell 7: Define Analysis Sub-Cohorts
#==========================================================================
# PURPOSE: Define all sub-cohorts needed for the planned analyses
# OUTPUT: Sub-cohort dataframes with sample sizes and characteristics
#==========================================================================

print("="*70)
print("ANALYSIS SUB-COHORT DEFINITIONS")
print("="*70)

#------------------------------------------------------------------------------
# SUB-COHORT 1: FULL COHORT (Mortality Analysis)
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("1. FULL COHORT — Mortality Analysis")
print("-"*70)
n_full = len(df)
n_deaths = df['hospital_expire_flag'].sum()
print(f"   N = {n_full:,}")
print(f"   Deaths: {n_deaths} ({100*n_deaths/n_full:.1f}%)")
print(f"   Use: In-hospital mortality (exploratory)")

#------------------------------------------------------------------------------
# SUB-COHORT 2: ICU STAY ≥24H (Binary Diuretic Resistance)
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("2. ICU STAY ≥24H — Binary Diuretic Resistance")
print("-"*70)
df_icu_24h = df[df['icu_stay_ge_24h'] == 1].copy()
print(f"   N = {len(df_icu_24h):,}")
print(f"   Exclusions: {n_full - len(df_icu_24h)} patients with ICU stay <24h")

# Verify/create diuretic_resistance variable (urine output ≤3,000 mL in first 24h)
print("\n   Diuretic Resistance Definition:")
if 'diuretic_resistance' in df.columns:
    print(f"   • Pre-calculated 'diuretic_resistance' column exists")
    # Verify against the expected definition when urine output is available
    if 'urine_output_24h_ml' in df.columns:
        expected_dr = (df['urine_output_24h_ml'] <= 3000).astype(int)
        match_pct = (df['diuretic_resistance'] == expected_dr).mean() * 100
        print(f"   • Verification against urine ≤3,000 mL: {match_pct:.1f}% match")

        if match_pct < 100:
            print(f"   • NOTE: Some discrepancy detected - creating verified variable")
            df['diuretic_resistance_verified'] = expected_dr
else:
    # Create diuretic_resistance if not present
    print(f"   • Creating 'diuretic_resistance' from urine_output_24h_ml ≤3,000 mL")
    df['diuretic_resistance'] = (df['urine_output_24h_ml'] <= 3000).astype(int)

# Re-slice the ICU ≥24h sub-cohort so it includes any column created above
# (df_icu_24h was copied before these columns existed)
df_icu_24h = df[df['icu_stay_ge_24h'] == 1].copy()
dr_n = df_icu_24h['diuretic_resistance'].sum()
dr_pct = 100 * dr_n / len(df_icu_24h)
print(f"   • DR prevalence (ICU≥24h): {dr_n} ({dr_pct:.1f}%)")

print(f"\n   Use: Binary DR outcome (urine output ≤3,000 mL in first 24h)")

#------------------------------------------------------------------------------
# SUB-COHORT 3: 24H EFFICIENCY COHORT (Primary Diuretic Efficiency)
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("3. 24H EFFICIENCY COHORT — Primary Diuretic Efficiency Analysis")
print("-"*70)
df_24h_eff = df[(df['icu_stay_ge_24h'] == 1) &
                (df['diuretic_efficiency_24h'].notna()) &
                (df['diuretic_efficiency_24h'] > 0)].copy()

print(f"   N = {len(df_24h_eff):,}")
print(f"   Criteria: ICU ≥24h AND calculable 24h efficiency >0 (requires IV diuretic dose >0)")
print(f"   Exclusions from ICU≥24h: {len(df_icu_24h) - len(df_24h_eff)} (efficiency missing or ≤0, e.g. no IV diuretics in first 24h)")
print(f"   Median efficiency: {df_24h_eff['diuretic_efficiency_24h'].median():.1f} mL/mg")
print(f"   Use: Spearman correlation, C-index, quintile/quartile AUROC")

# Define and SAVE quintile and quartile thresholds for later use
QUINTILE_THRESHOLD_24H = df_24h_eff['diuretic_efficiency_24h'].quantile(0.20)
QUARTILE_THRESHOLD_24H = df_24h_eff['diuretic_efficiency_24h'].quantile(0.25)

print(f"\n   Efficiency Thresholds (SAVED FOR LATER USE):")
print(f"   • QUINTILE_THRESHOLD_24H (20th %ile): ≤{QUINTILE_THRESHOLD_24H:.1f} mL/mg")
print(f"   • QUARTILE_THRESHOLD_24H (25th %ile): ≤{QUARTILE_THRESHOLD_24H:.1f} mL/mg")

# Create binary outcome variables for AUROC analysis
df_24h_eff['lowest_quintile_24h'] = (df_24h_eff['diuretic_efficiency_24h'] <= QUINTILE_THRESHOLD_24H).astype(int)
df_24h_eff['lowest_quartile_24h'] = (df_24h_eff['diuretic_efficiency_24h'] <= QUARTILE_THRESHOLD_24H).astype(int)

n_quintile = df_24h_eff['lowest_quintile_24h'].sum()
n_quartile = df_24h_eff['lowest_quartile_24h'].sum()
print(f"\n   Binary Outcomes Created:")
print(f"   • Lowest quintile: {n_quintile} ({100*n_quintile/len(df_24h_eff):.1f}%)")
print(f"   • Lowest quartile: {n_quartile} ({100*n_quartile/len(df_24h_eff):.1f}%)")

#------------------------------------------------------------------------------
# SUB-COHORT 4: ICU STAY ≥72H
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("4. ICU STAY ≥72H")
print("-"*70)
df_icu_72h = df[df['icu_stay_ge_72h'] == 1].copy()
print(f"   N = {len(df_icu_72h):,}")
print(f"   Exclusions: {n_full - len(df_icu_72h)} patients with ICU stay <72h")

#------------------------------------------------------------------------------
# SUB-COHORT 5: 72H EFFICIENCY COHORT
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("5. 72H EFFICIENCY COHORT — Secondary Diuretic Efficiency Analysis")
print("-"*70)
df_72h_eff = df[(df['icu_stay_ge_72h'] == 1) &
                (df['diuretic_efficiency_72h'].notna()) &
                (df['diuretic_efficiency_72h'] > 0)].copy()

print(f"   N = {len(df_72h_eff):,}")
print(f"   Criteria: ICU ≥72h AND calculable 72h efficiency >0 (requires IV diuretic dose >0)")
print(f"   Exclusions from ICU≥72h: {len(df_icu_72h) - len(df_72h_eff)} (efficiency missing or ≤0, e.g. no IV diuretics in first 72h)")
print(f"   Median efficiency: {df_72h_eff['diuretic_efficiency_72h'].median():.1f} mL/mg")
print(f"   Use: Spearman correlation, C-index (matches original derivation endpoint)")

# Save 72h thresholds as well
QUINTILE_THRESHOLD_72H = df_72h_eff['diuretic_efficiency_72h'].quantile(0.20)
QUARTILE_THRESHOLD_72H = df_72h_eff['diuretic_efficiency_72h'].quantile(0.25)

print(f"\n   Efficiency Thresholds (SAVED FOR LATER USE):")
print(f"   • QUINTILE_THRESHOLD_72H (20th %ile): ≤{QUINTILE_THRESHOLD_72H:.1f} mL/mg")
print(f"   • QUARTILE_THRESHOLD_72H (25th %ile): ≤{QUARTILE_THRESHOLD_72H:.1f} mL/mg")

#------------------------------------------------------------------------------
# SUB-COHORT 6: CARDIOGENIC SHOCK
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("6. CARDIOGENIC SHOCK SUBGROUP")
print("-"*70)
df_cs = df[df['cardiogenic_shock'] == 1].copy()
df_no_cs = df[df['cardiogenic_shock'] == 0].copy()

print(f"   Cardiogenic shock (CS):    N = {len(df_cs):,} ({100*len(df_cs)/n_full:.1f}%)")
print(f"   No cardiogenic shock:      N = {len(df_no_cs):,} ({100*len(df_no_cs)/n_full:.1f}%)")
print(f"   CS mortality: {100*df_cs['hospital_expire_flag'].mean():.1f}%")
print(f"   Non-CS mortality: {100*df_no_cs['hospital_expire_flag'].mean():.1f}%")
print(f"   Use: Subgroup analysis, sensitivity analysis")

#------------------------------------------------------------------------------
# SUB-COHORT 7: COMPLETE LVEF DATA
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("7. COMPLETE LVEF DATA")
print("-"*70)
df_lvef = df[df['lvef'].notna()].copy()
print(f"   N = {len(df_lvef):,} ({100*len(df_lvef)/n_full:.1f}% of cohort)")
print(f"   Missing LVEF: {n_full - len(df_lvef)}")

if 'hf_phenotype' in df.columns:
    print(f"\n   HF Phenotype Distribution:")
    for phenotype in df_lvef['hf_phenotype'].dropna().unique():
        n = (df_lvef['hf_phenotype'] == phenotype).sum()
        print(f"   • {phenotype}: {n} ({100*n/len(df_lvef):.1f}%)")

#------------------------------------------------------------------------------
# SUB-COHORT 8: EXCLUDE ADVANCED CKD (Sensitivity Analysis)
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("8. EXCLUDE ADVANCED CKD — Sensitivity Analysis")
print("-"*70)
df_no_ckd = df[df['chronic_advanced_ckd'] == 0].copy()
print(f"   N = {len(df_no_ckd):,}")
print(f"   Excluded: {n_full - len(df_no_ckd)} with chronic advanced CKD")
print(f"   Use: Sensitivity analysis (different pathophysiology)")

#------------------------------------------------------------------------------
# SUMMARY TABLE
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SUB-COHORT SUMMARY")
print("="*70)

print(f"\n{'Sub-Cohort':<35} {'N':<10} {'Purpose':<30}")
print("-"*75)
print(f"{'Full cohort':<35} {n_full:<10} {'Mortality (exploratory)':<30}")
print(f"{'ICU ≥24h':<35} {len(df_icu_24h):<10} {'Binary diuretic resistance':<30}")
print(f"{'24h efficiency cohort':<35} {len(df_24h_eff):<10} {'Correlation, C-index, AUROC':<30}")
print(f"{'ICU ≥72h':<35} {len(df_icu_72h):<10} {'72h outcomes':<30}")
print(f"{'72h efficiency cohort':<35} {len(df_72h_eff):<10} {'72h correlation, C-index':<30}")
print(f"{'Cardiogenic shock':<35} {len(df_cs):<10} {'Subgroup analysis':<30}")
print(f"{'No cardiogenic shock':<35} {len(df_no_cs):<10} {'Sensitivity analysis':<30}")
print(f"{'Complete LVEF':<35} {len(df_lvef):<10} {'HF phenotype subgroups':<30}")
print(f"{'Exclude advanced CKD':<35} {len(df_no_ckd):<10} {'Sensitivity analysis':<30}")

#------------------------------------------------------------------------------
# SAVED THRESHOLDS SUMMARY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SAVED THRESHOLDS FOR AUROC ANALYSIS")
print("="*70)
print(f"\n24-Hour Efficiency (N={len(df_24h_eff):,}):")
print(f"   QUINTILE_THRESHOLD_24H = {QUINTILE_THRESHOLD_24H:.1f} mL/mg (20th percentile)")
print(f"   QUARTILE_THRESHOLD_24H = {QUARTILE_THRESHOLD_24H:.1f} mL/mg (25th percentile)")
print(f"\n72-Hour Efficiency (N={len(df_72h_eff):,}):")
print(f"   QUINTILE_THRESHOLD_72H = {QUINTILE_THRESHOLD_72H:.1f} mL/mg (20th percentile)")
print(f"   QUARTILE_THRESHOLD_72H = {QUARTILE_THRESHOLD_72H:.1f} mL/mg (25th percentile)")

#------------------------------------------------------------------------------
# RISK CATEGORY DISTRIBUTION BY KEY SUB-COHORTS
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("RISK CATEGORY DISTRIBUTION BY KEY SUB-COHORTS")
print("-"*70)

def show_risk_dist(data, name):
    print(f"\n{name} (N={len(data):,}):")
    for cat in ['Low', 'Moderate', 'High']:
        n = (data['risk_category'] == cat).sum()
        pct = 100 * n / len(data)
        print(f"   {cat:<10}: {n:>4} ({pct:>5.1f}%)")

show_risk_dist(df, "Full Cohort")
show_risk_dist(df_24h_eff, "24h Efficiency Cohort")
show_risk_dist(df_72h_eff, "72h Efficiency Cohort")

print("\n" + "="*70)
print("SUB-COHORT DEFINITIONS COMPLETE")
print("="*70)
print("\nKey variables created/saved:")
print("  • QUINTILE_THRESHOLD_24H, QUARTILE_THRESHOLD_24H")
print("  • QUINTILE_THRESHOLD_72H, QUARTILE_THRESHOLD_72H")
print("  • df_24h_eff['lowest_quintile_24h'], df_24h_eff['lowest_quartile_24h']")
print("  • diuretic_resistance verified")
print("\n→ Next: Generate Table 1 (Baseline Characteristics)")
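Because the efficiency cutoffs are applied with `<=`, the "lowest quintile" flag can cover more than 20% of patients when several patients share the threshold value, which is why the cell prints the realized proportions. A small sketch on made-up values (not study data), assuming pandas' default linear interpolation in `Series.quantile`, showing how ties at the cutoff inflate the flagged share:

```python
import pandas as pd

# Made-up efficiency values (mL/mg); five patients share the lowest value
eff = pd.Series([5.0, 5.0, 5.0, 5.0, 5.0, 12.0, 20.0, 31.0, 44.0, 60.0])

# Series.quantile uses linear interpolation by default
q20 = eff.quantile(0.20)
share = (eff <= q20).mean()  # fraction flagged as "lowest quintile"
print(f"With ties:    20th percentile = {q20:.1f}, flagged share = {share:.0%}")

# The same rule on tie-free values flags exactly 20%
eff_no_ties = pd.Series([float(x) for x in range(1, 11)])
q20_nt = eff_no_ties.quantile(0.20)
share_nt = (eff_no_ties <= q20_nt).mean()
print(f"Without ties: 20th percentile = {q20_nt:.1f}, flagged share = {share_nt:.0%}")
```

In a cohort of about a thousand continuous efficiency values, ties are usually rare, so the printed shares should sit close to 20% and 25%; a noticeably larger share would point to clustered values at the cutoff.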

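One caveat worth checking in the diuretic resistance step of the sub-cohort cell: in pandas, comparing a missing value with `<=` returns `False`, so `(urine_output_24h_ml <= 3000).astype(int)` codes patients with missing 24h urine output as 0 (not resistant). A minimal sketch on made-up values of how to count those patients and of an alternative coding that keeps them missing via the nullable `Int64` dtype; whether missing urine output should be excluded or imputed is an analytic decision this sketch does not make:

```python
import numpy as np
import pandas as pd

# Made-up 24h urine outputs (mL); two patients have no recorded output
urine = pd.Series([2500.0, 4200.0, np.nan, 3000.0, np.nan])

# Coding used in the cell above: NaN <= 3000 is False, so missing becomes 0
dr_naive = (urine <= 3000).astype(int)

# Alternative coding: nullable Int64 keeps missing urine output as <NA>
dr_nullable = (urine <= 3000).astype('Int64')
dr_nullable[urine.isna()] = pd.NA

print(f"Missing urine output: {int(urine.isna().sum())}")
print(f"DR, missing coded as 0:   {dr_naive.tolist()}")
print(f"DR, missing kept as <NA>: {dr_nullable.tolist()}")
# mean() skips <NA>, so this is prevalence among patients with recorded output
print(f"DR prevalence among non-missing: {dr_nullable.mean():.0%}")
```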
#==========================================================================
# SECTION 3: TABLE 1 - BASELINE CHARACTERISTICS
# Cell 8: Mortality Status
#==========================================================================
# PURPOSE: Create Table 1 showing baseline characteristics stratified by
#          in-hospital mortality (Survivors vs Non-survivors)
# NOTE: Diuretic resistance is an OUTCOME, shown in Table 3, not here
#==========================================================================

from tableone import TableOne
import warnings
warnings.filterwarnings('ignore')

print("="*70)
print("TABLE 1: BASELINE CHARACTERISTICS BY MORTALITY STATUS")
print("="*70)

#------------------------------------------------------------------------------
# PREPARE DATA
#------------------------------------------------------------------------------

df_table1 = df.copy()

# Create mortality status variable
df_table1['mortality_status'] = df_table1['hospital_expire_flag'].map({0: 'Survivors', 1: 'Non-survivors'})

# Create male_sex variable
if df_table1['gender'].dtype == 'object':
    df_table1['male_sex'] = (df_table1['gender'] == 'M').astype(int)
else:
    df_table1['male_sex'] = df_table1['gender']

# Risk category as string
df_table1['risk_category'] = df_table1['risk_category'].astype(str)

#------------------------------------------------------------------------------
# DEFINE VARIABLES FOR TABLE 1 (Baseline characteristics only)
#------------------------------------------------------------------------------

all_columns = [
    # Demographics
    'age',
    'male_sex',
    # BAN-ADHF Score
    'ban_adhf_total_score',
    'risk_category',
    # Score Components (raw values)
    'creatinine',
    'bun',
    'ntprobnp',
    'dbp',
    'total_furosemide_equivalent_mg',
    'hx_atrial_fibrillation',
    'hx_hypertension',
    'prior_hf_hospitalization_12mo',
    # Heart Failure Characteristics
    'lvef',
    'hf_phenotype',
    # Comorbidities
    'hx_diabetes',
    'hx_renal_disease',
    'hx_myocardial_infarction',
    'hx_stroke',
    'hx_copd',
    'cci_score',
    # Clinical Presentation
    'cardiogenic_shock',
    'invasive_vent'
]

# Filter to existing columns
all_columns = [v for v in all_columns if v in df_table1.columns]

# Categorical variables
categorical_vars = [
    'male_sex',
    'risk_category',
    'hx_atrial_fibrillation',
    'hx_hypertension',
    'prior_hf_hospitalization_12mo',
    'hf_phenotype',
    'hx_diabetes',
    'hx_renal_disease',
    'hx_myocardial_infarction',
    'hx_stroke',
    'hx_copd',
    'cardiogenic_shock',
    'invasive_vent'
]
categorical_vars = [v for v in categorical_vars if v in df_table1.columns]

# Non-normal continuous variables
nonnormal_vars = [
    'ban_adhf_total_score',
    'creatinine',
    'bun',
    'ntprobnp',
    'total_furosemide_equivalent_mg',
    'cci_score'
]
nonnormal_vars = [v for v in nonnormal_vars if v in df_table1.columns]

print(f"Total variables: {len(all_columns)}")

#------------------------------------------------------------------------------
# LABELS
#------------------------------------------------------------------------------

labels = {
    'age': 'Age, years',
    'male_sex': 'Male sex',
    'ban_adhf_total_score': 'Total score',
    'risk_category': 'Risk category',
    'creatinine': 'Creatinine, mg/dL',
    'bun': 'BUN, mg/dL',
    'ntprobnp': 'NT-proBNP, pg/mL',
    'dbp': 'Diastolic BP, mmHg',
    'total_furosemide_equivalent_mg': 'Home diuretic dose, mg/day †',
    'hx_atrial_fibrillation': 'Atrial fibrillation',
    'hx_hypertension': 'Hypertension',
    'prior_hf_hospitalization_12mo': 'Prior HF hospitalization (12 months)',
    'lvef': 'LVEF, % ‡',
    'hf_phenotype': 'HF phenotype ‡',
    'hx_diabetes': 'Diabetes mellitus',
    'hx_renal_disease': 'Chronic kidney disease',
    'hx_myocardial_infarction': 'Prior myocardial infarction',
    'hx_stroke': 'Prior stroke',
    'hx_copd': 'COPD',
    'cci_score': 'Charlson Comorbidity Index',
    'cardiogenic_shock': 'Cardiogenic shock',
    'invasive_vent': 'Invasive mechanical ventilation'
}

#------------------------------------------------------------------------------
# GENERATE TABLE
#------------------------------------------------------------------------------

order = {
    'mortality_status': ['Survivors', 'Non-survivors'],
    'risk_category': ['Low', 'Moderate', 'High']
}

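# Display one row per binary variable. tableone's `limit` keeps the top-N most
# frequent levels, so confirm each displayed row is the '1' (present) level.
limit = {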
    'male_sex': 1,
    'hx_atrial_fibrillation': 1,
    'hx_hypertension': 1,
    'prior_hf_hospitalization_12mo': 1,
    'hx_diabetes': 1,
    'hx_renal_disease': 1,
    'hx_myocardial_infarction': 1,
    'hx_stroke': 1,
    'hx_copd': 1,
    'cardiogenic_shock': 1,
    'invasive_vent': 1
}

table1 = TableOne(
    df_table1,
    columns=all_columns,
    categorical=categorical_vars,
    nonnormal=nonnormal_vars,
    groupby='mortality_status',
    pval=True,
    rename=labels,
    missing=False,
    overall=True,
    order=order,
    limit=limit,
    decimals=1
)

print("\n")
print(table1.tabulate(tablefmt="simple"))

#------------------------------------------------------------------------------
# EXPORT TABLE
#------------------------------------------------------------------------------

table1_df = table1.tableone
table1_df.to_csv('/content/Table1_by_mortality.csv')
print("\n✓ Table 1 saved to 'Table1_by_mortality.csv'")

table1_df.to_excel('/content/Table1_by_mortality.xlsx')
print("✓ Table 1 saved to 'Table1_by_mortality.xlsx'")

#------------------------------------------------------------------------------
# VERIFICATION
#------------------------------------------------------------------------------

print("\n" + "="*70)
print("VERIFICATION")
print("="*70)

survivors = df_table1[df_table1['mortality_status'] == 'Survivors']
nonsurvivors = df_table1[df_table1['mortality_status'] == 'Non-survivors']

print(f"""
Sample Size:
  • Overall:        N = {len(df_table1):,}
  • Survivors:      N = {len(survivors):,}
  • Non-survivors:  N = {len(nonsurvivors):,}

BAN-ADHF Score:
  • Survivors:      {survivors['ban_adhf_total_score'].mean():.1f} ± {survivors['ban_adhf_total_score'].std():.1f}
  • Non-survivors:  {nonsurvivors['ban_adhf_total_score'].mean():.1f} ± {nonsurvivors['ban_adhf_total_score'].std():.1f}

Risk Category (Non-survivors):
  • Low (≤7):       {(nonsurvivors['risk_category']=='Low').sum()} ({100*(nonsurvivors['risk_category']=='Low').mean():.1f}%)
  • Moderate (8-12): {(nonsurvivors['risk_category']=='Moderate').sum()} ({100*(nonsurvivors['risk_category']=='Moderate').mean():.1f}%)
  • High (≥13):     {(nonsurvivors['risk_category']=='High').sum()} ({100*(nonsurvivors['risk_category']=='High').mean():.1f}%)
""")

#------------------------------------------------------------------------------
# FOOTNOTES
#------------------------------------------------------------------------------

lvef_n = df_table1['lvef'].notna().sum()
lvef_pct = 100 * lvef_n / len(df_table1)

print("="*70)
print("TABLE 1 FOOTNOTES")
print("="*70)
print(f"""
Values are mean ± SD, median [IQR], or n (%). P-values from Student's t-test
or Mann-Whitney U test for continuous variables and chi-square test for
categorical variables.

† Home diuretic dose expressed as oral furosemide equivalents.
‡ Available in {lvef_n} patients ({lvef_pct:.1f}%); percentages calculated
  among those with available data.

Abbreviations: BUN, blood urea nitrogen; BP, blood pressure; COPD, chronic
obstructive pulmonary disease; HF, heart failure; HFmrEF, heart failure with
mildly reduced ejection fraction; HFpEF, heart failure with preserved ejection
fraction; HFrEF, heart failure with reduced ejection fraction; IQR, interquartile
range; LVEF, left ventricular ejection fraction; NT-proBNP, N-terminal pro-B-type
natriuretic peptide.
""")

print("→ Next: Primary Analysis - 24h Diuretic Efficiency")
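To spot-check individual p-values in Table 1, the two-group tests named in the footnote can be run directly with SciPy. A hedged sketch on simulated data: the column names mirror the notebook but the values are invented, and tableone's exact test choices and options may differ by version, so treat this as a cross-check rather than a reproduction:

```python
import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu, chi2_contingency

# Simulated cohort (not study data)
rng = np.random.default_rng(1)
n = 200
demo = pd.DataFrame({
    'hospital_expire_flag': rng.integers(0, 2, size=n),
    'creatinine': rng.lognormal(mean=0.3, sigma=0.5, size=n),
    'cardiogenic_shock': rng.integers(0, 2, size=n),
})

# Non-normal continuous variable: two-sided Mann-Whitney U test
surv = demo.loc[demo['hospital_expire_flag'] == 0, 'creatinine']
dead = demo.loc[demo['hospital_expire_flag'] == 1, 'creatinine']
u_stat, p_mwu = mannwhitneyu(surv, dead, alternative='two-sided')

# Binary variable: chi-square test on the 2x2 table
# (SciPy applies Yates' continuity correction to 2x2 tables by default)
table = pd.crosstab(demo['cardiogenic_shock'], demo['hospital_expire_flag'])
chi2, p_chi2, dof, _ = chi2_contingency(table)

print(f"Creatinine: median {surv.median():.2f} vs {dead.median():.2f}, Mann-Whitney p = {p_mwu:.3f}")
print(f"Cardiogenic shock: chi-square = {chi2:.2f} (df = {dof}), p = {p_chi2:.3f}")
```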

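The next cell computes Harrell's C with `lifelines.utils.concordance_index`, passing efficiency as the "time", the negated score as the prediction, and marking every observation as an event. With no censoring, this should reduce to the share of concordant pairs among pairs with different efficiencies, counting score ties as one half (worth confirming against lifelines on a small sample). A brute-force O(n²) sketch of that definition, intended only for checking small samples; the function name `pairwise_concordance` is hypothetical:

```python
from itertools import combinations

def pairwise_concordance(efficiency, score):
    """Share of patient pairs in which the higher score goes with the lower efficiency.

    Pairs tied on efficiency are not comparable and are skipped;
    pairs tied on score count as 0.5 (no censoring assumed).
    """
    concordant = 0.0
    comparable = 0
    for (e_i, s_i), (e_j, s_j) in combinations(zip(efficiency, score), 2):
        if e_i == e_j:
            continue  # tied outcomes carry no ordering information
        comparable += 1
        if s_i == s_j:
            concordant += 0.5
        elif (s_i > s_j) == (e_i < e_j):
            concordant += 1.0  # higher score paired with lower efficiency
    if comparable == 0:
        raise ValueError("no comparable pairs (all efficiencies tied)")
    return concordant / comparable

eff = [10.0, 25.0, 40.0, 40.0, 60.0]
score = [15, 12, 12, 9, 6]
print(f"C = {pairwise_concordance(eff, score):.3f}")  # → C = 0.944 (8.5 of 9 comparable pairs)
```

Negating the score in the lifelines call plays the same role as the `(s_i > s_j) == (e_i < e_j)` test here: both count a pair as concordant when the higher score accompanies the lower efficiency.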
#==========================================================================
# SECTION 4: CO-PRIMARY OUTCOME 1 - 24-HOUR DIURETIC EFFICIENCY
# Cell 9: Correlation, C-index, and AUROC Analysis
#==========================================================================
# PURPOSE: Evaluate BAN-ADHF score's ability to predict 24h diuretic efficiency
# POPULATION: N=1,019 (ICU ≥24h, IV diuretic dose >0, calculable 24h efficiency >0)
# METRICS: Spearman correlation (primary), Pearson (for Mauch comparison),
#          C-index, AUROC (quintile/quartile)
# OUTPUT: Primary discrimination statistics with 95% CIs
#==========================================================================

from scipy.stats import spearmanr, pearsonr, kruskal
from sklearn.metrics import roc_auc_score, roc_curve
from lifelines.utils import concordance_index
import numpy as np

print("="*70)
print("CO-PRIMARY OUTCOME 1: 24-HOUR DIURETIC EFFICIENCY")
print("="*70)
print("\nBAN-ADHF was designed to predict diuretic efficiency.")
print("This is the score's PRIMARY intended purpose.")

#------------------------------------------------------------------------------
# ANALYSIS COHORT
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("ANALYSIS COHORT")
print("-"*70)

df_24h = df[(df['icu_stay_ge_24h'] == 1) &
            (df['diuretic_efficiency_24h'].notna()) &
            (df['diuretic_efficiency_24h'] > 0)].copy()

print("Population: ICU ≥24h, IV diuretic dose >0 and calculable efficiency >0 in first 24 hours")
print(f"N = {len(df_24h):,}")
print(f"\nDiuretic Efficiency Distribution:")
print(f"  Mean ± SD:    {df_24h['diuretic_efficiency_24h'].mean():.1f} ± {df_24h['diuretic_efficiency_24h'].std():.1f} mL/mg")
print(f"  Median (IQR): {df_24h['diuretic_efficiency_24h'].median():.1f} ({df_24h['diuretic_efficiency_24h'].quantile(0.25):.1f}-{df_24h['diuretic_efficiency_24h'].quantile(0.75):.1f}) mL/mg")
print(f"  Range:        {df_24h['diuretic_efficiency_24h'].min():.1f} - {df_24h['diuretic_efficiency_24h'].max():.1f} mL/mg")

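# Check skewness (|skew| > 0.5 treated as meaningfully skewed)
skewness = df_24h['diuretic_efficiency_24h'].skew()
skew_label = ('right-skewed' if skewness > 0.5
              else 'left-skewed' if skewness < -0.5
              else 'approximately symmetric')
print(f"  Skewness:     {skewness:.2f} ({skew_label})")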
print(f"\n  → Spearman ρ used as primary metric (robust to skewed distributions)")

#------------------------------------------------------------------------------
# BOOTSTRAP SETUP
#------------------------------------------------------------------------------
np.random.seed(42)
n_bootstrap = 1000

#------------------------------------------------------------------------------
# 1. SPEARMAN CORRELATION (PRIMARY)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("1. SPEARMAN CORRELATION (PRIMARY METRIC)")
print("="*70)
print("Higher BAN-ADHF score should predict LOWER diuretic efficiency")
print("(inverse correlation expected)")

# Calculate Spearman correlation
rho, rho_p = spearmanr(df_24h['ban_adhf_total_score'],
                        df_24h['diuretic_efficiency_24h'])

# Bootstrap 95% CI for Spearman
rho_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_24h), size=len(df_24h), replace=True)
    sample = df_24h.iloc[idx]
    r, _ = spearmanr(sample['ban_adhf_total_score'], sample['diuretic_efficiency_24h'])
    rho_bootstrap.append(r)

rho_ci_lower = np.percentile(rho_bootstrap, 2.5)
rho_ci_upper = np.percentile(rho_bootstrap, 97.5)

print(f"\nSpearman ρ = {rho:.3f} (95% CI: {rho_ci_lower:.3f} to {rho_ci_upper:.3f})")
print(f"P-value: < 0.001" if rho_p < 0.001 else f"P-value: {rho_p:.4f}")

# Interpretation
print(f"\nInterpretation:")
if abs(rho) > 0.5:
    print(f"  Strong inverse correlation (|ρ| > 0.5)")
elif abs(rho) > 0.3:
    print(f"  Moderate inverse correlation (|ρ| 0.3-0.5)")
else:
    print(f"  Weak inverse correlation (|ρ| < 0.3)")

#------------------------------------------------------------------------------
# 1b. PEARSON CORRELATION (FOR MAUCH COMPARISON)
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("1b. PEARSON CORRELATION (For Mauch 2025 Comparison)")
print("-"*70)

# Calculate Pearson correlation
pearson_r, pearson_p = pearsonr(df_24h['ban_adhf_total_score'],
                                 df_24h['diuretic_efficiency_24h'])

# Bootstrap CI for Pearson
pearson_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_24h), size=len(df_24h), replace=True)
    sample = df_24h.iloc[idx]
    r, _ = pearsonr(sample['ban_adhf_total_score'], sample['diuretic_efficiency_24h'])
    pearson_bootstrap.append(r)

pearson_ci_lower = np.percentile(pearson_bootstrap, 2.5)
pearson_ci_upper = np.percentile(pearson_bootstrap, 97.5)

print(f"\nPearson r = {pearson_r:.3f} (95% CI: {pearson_ci_lower:.3f} to {pearson_ci_upper:.3f})")
# Reference Pearson r reported by Mauch 2025 (floor patients)
mauch_r = -0.40
print(f"Mauch 2025 (floor patients): r = {mauch_r:.2f}")

# Relative difference in correlation magnitude vs. Mauch
pct_diff = ((abs(pearson_r) - abs(mauch_r)) / abs(mauch_r)) * 100
print(f"\n→ Our Pearson r is {abs(pct_diff):.1f}% {'stronger' if abs(pearson_r) > abs(mauch_r) else 'weaker'} than Mauch")
print(f"\nNote: Spearman ρ is primary metric due to right-skewed efficiency distribution")

#------------------------------------------------------------------------------
# 2. C-INDEX (HARRELL'S CONCORDANCE INDEX)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("2. C-INDEX (HARRELL'S CONCORDANCE INDEX)")
print("="*70)
print("Probability that, in a pair with different efficiencies, the patient with the higher score has the lower efficiency (score ties count 0.5)")

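# Calculate C-index with lifelines: efficiency is passed as the "event time",
# the negated score as the prediction (higher = expected higher efficiency),
# and every observation is marked as an event (no censoring)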
c_index = concordance_index(
    df_24h['diuretic_efficiency_24h'],
    -df_24h['ban_adhf_total_score'],
    np.ones(len(df_24h))
)

# Bootstrap 95% CI for C-index
c_index_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_24h), size=len(df_24h), replace=True)
    sample = df_24h.iloc[idx]
    c = concordance_index(
        sample['diuretic_efficiency_24h'],
        -sample['ban_adhf_total_score'],
        np.ones(len(sample))
    )
    c_index_bootstrap.append(c)

c_index_ci_lower = np.percentile(c_index_bootstrap, 2.5)
c_index_ci_upper = np.percentile(c_index_bootstrap, 97.5)

print(f"\nC-index = {c_index:.3f} (95% CI: {c_index_ci_lower:.3f} to {c_index_ci_upper:.3f})")

print(f"\nInterpretation:")
if c_index >= 0.7:
    print(f"  Good discrimination (C-index ≥ 0.7)")
elif c_index >= 0.6:
    print(f"  Moderate discrimination (C-index 0.6-0.7)")
else:
    print(f"  Poor discrimination (C-index < 0.6)")
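# --- Sketch: brute-force check of the concordance probability ---
# Not part of the original pipeline. `pairwise_concordance` is a hypothetical
# helper that spells out the quantity printed above: among pairs of patients with
# different 24h efficiency, the share in which the lower-efficiency patient has
# the higher BAN-ADHF score (tied scores count 0.5). It is intended to agree with
# lifelines' concordance_index(efficiency, -score, all events observed); verify
# on your data before relying on it. It builds n x n comparison matrices, so
# memory grows as O(n^2) (about 1M cells per matrix for n ≈ 1,000).
import numpy as np

def pairwise_concordance(efficiency, score):
    e = np.asarray(efficiency, dtype=float)
    s = np.asarray(score, dtype=float)
    # lower[i, j] is True when patient i has strictly lower efficiency than j;
    # each pair with distinct efficiency appears exactly once, tied pairs never
    lower = e[:, None] < e[None, :]
    higher_score = s[:, None] > s[None, :]
    tied_score = s[:, None] == s[None, :]
    concordant = (lower & higher_score).sum() + 0.5 * (lower & tied_score).sum()
    return float(concordant / lower.sum())

# Example check (uncomment to run):
# print(pairwise_concordance(df_24h['diuretic_efficiency_24h'], df_24h['ban_adhf_total_score']))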

#------------------------------------------------------------------------------
# 3. AUROC FOR LOWEST QUINTILE (Bottom 20%)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("3. AUROC FOR LOWEST QUINTILE OF EFFICIENCY")
print("="*70)

quintile_threshold = df_24h['diuretic_efficiency_24h'].quantile(0.20)
df_24h['lowest_quintile'] = (df_24h['diuretic_efficiency_24h'] <= quintile_threshold).astype(int)

n_quintile = df_24h['lowest_quintile'].sum()
print(f"Threshold: ≤{quintile_threshold:.1f} mL/mg (20th percentile)")
print(f"Patients in lowest quintile: {n_quintile} ({100*n_quintile/len(df_24h):.1f}%)")

auroc_quintile = roc_auc_score(df_24h['lowest_quintile'], df_24h['ban_adhf_total_score'])

# Bootstrap 95% CI
auroc_q_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_24h), size=len(df_24h), replace=True)
    sample = df_24h.iloc[idx]
    if sample['lowest_quintile'].sum() > 0 and sample['lowest_quintile'].sum() < len(sample):
        auc = roc_auc_score(sample['lowest_quintile'], sample['ban_adhf_total_score'])
        auroc_q_bootstrap.append(auc)

auroc_q_ci_lower = np.percentile(auroc_q_bootstrap, 2.5)
auroc_q_ci_upper = np.percentile(auroc_q_bootstrap, 97.5)

print(f"\nAUROC = {auroc_quintile:.3f} (95% CI: {auroc_q_ci_lower:.3f} to {auroc_q_ci_upper:.3f})")
print(f"Segar 2024 (trial validation): AUROC = 0.84")

segar_diff = auroc_quintile - 0.84
print(f"→ Difference from Segar: {segar_diff:+.3f}")
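# --- Sketch: AUROC as a rank probability ---
# Not part of the original pipeline. `auroc_as_probability` is a hypothetical
# helper illustrating what roc_auc_score reports here: the probability that a
# randomly chosen lowest-quintile patient has a higher BAN-ADHF score than a
# randomly chosen patient outside the quintile, with tied scores counted as 1/2
# (the Mann-Whitney formulation). Integer scores produce many ties, so the 1/2
# term matters.
import numpy as np

def auroc_as_probability(y, score):
    y = np.asarray(y).astype(bool)
    s = np.asarray(score, dtype=float)
    pos, neg = s[y], s[~y]
    # Compare every positive (lowest quintile) with every negative
    greater = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return float(greater + 0.5 * ties)

# Example check (uncomment to run; should match auroc_quintile):
# print(auroc_as_probability(df_24h['lowest_quintile'], df_24h['ban_adhf_total_score']))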

#------------------------------------------------------------------------------
# 4. AUROC FOR LOWEST QUARTILE (Bottom 25%)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("4. AUROC FOR LOWEST QUARTILE OF EFFICIENCY")
print("="*70)

quartile_threshold = df_24h['diuretic_efficiency_24h'].quantile(0.25)
df_24h['lowest_quartile'] = (df_24h['diuretic_efficiency_24h'] <= quartile_threshold).astype(int)

n_quartile = df_24h['lowest_quartile'].sum()
print(f"Threshold: ≤{quartile_threshold:.1f} mL/mg (25th percentile)")
print(f"Patients in lowest quartile: {n_quartile} ({100*n_quartile/len(df_24h):.1f}%)")

auroc_quartile = roc_auc_score(df_24h['lowest_quartile'], df_24h['ban_adhf_total_score'])

# Bootstrap 95% CI
auroc_qt_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_24h), size=len(df_24h), replace=True)
    sample = df_24h.iloc[idx]
    if sample['lowest_quartile'].sum() > 0 and sample['lowest_quartile'].sum() < len(sample):
        auc = roc_auc_score(sample['lowest_quartile'], sample['ban_adhf_total_score'])
        auroc_qt_bootstrap.append(auc)

auroc_qt_ci_lower = np.percentile(auroc_qt_bootstrap, 2.5)
auroc_qt_ci_upper = np.percentile(auroc_qt_bootstrap, 97.5)

print(f"\nAUROC = {auroc_quartile:.3f} (95% CI: {auroc_qt_ci_lower:.3f} to {auroc_qt_ci_upper:.3f})")
print(f"Pandey 2025 (CLOROTIC trial): AUROC = 0.70")

pandey_diff = auroc_quartile - 0.70
print(f"→ Difference from Pandey: {pandey_diff:+.3f} ({100*pandey_diff/0.70:+.1f}%)")

#------------------------------------------------------------------------------
# 5. EFFICIENCY BY RISK CATEGORY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("5. DIURETIC EFFICIENCY BY RISK CATEGORY")
print("="*70)

print(f"\n{'Risk Category':<15} {'N':<8} {'Median (IQR)':<25} {'Mean ± SD':<20}")
print("-"*70)

for cat in ['Low', 'Moderate', 'High']:
    subset = df_24h[df_24h['risk_category'] == cat]['diuretic_efficiency_24h']
    n = len(subset)
    median = subset.median()
    q1 = subset.quantile(0.25)
    q3 = subset.quantile(0.75)
    mean = subset.mean()
    std = subset.std()
    iqr_str = f"{median:.1f} ({q1:.1f}-{q3:.1f})"
    print(f"{cat:<15} {n:<8} {iqr_str:<25} {mean:.1f} ± {std:.1f}")

# Kruskal-Wallis test
groups = [df_24h[df_24h['risk_category'] == cat]['diuretic_efficiency_24h'].values
          for cat in ['Low', 'Moderate', 'High']]
kw_stat, kw_p = kruskal(*groups)
print(f"\nKruskal-Wallis H = {kw_stat:.1f}, p < 0.001" if kw_p < 0.001 else f"\nKruskal-Wallis H = {kw_stat:.1f}, p = {kw_p:.3f}")
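# --- Sketch: reconstructing the risk categories from the score ---
# Not part of the original pipeline. The `risk_category` column is created
# earlier in the notebook; `assign_risk_category` is a hypothetical helper that
# applies the cutoffs stated in the study overview (Low ≤7, Moderate 8-12,
# High ≥13), assuming integer BAN-ADHF scores. It can serve as a consistency
# check on the stored column.
import numpy as np
import pandas as pd

def assign_risk_category(score):
    # pd.cut intervals are right-closed: (-inf, 7] -> Low, (7, 12] -> Moderate, (12, inf) -> High
    return pd.cut(pd.Series(score, dtype=float),
                  bins=[-np.inf, 7, 12, np.inf],
                  labels=['Low', 'Moderate', 'High'])

# Example check (uncomment to run; True if the stored column uses the same cutoffs):
# print((assign_risk_category(df_24h['ban_adhf_total_score']).astype(str)
#        == df_24h['risk_category'].astype(str)).all())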

#------------------------------------------------------------------------------
# LITERATURE COMPARISON TABLE
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("LITERATURE COMPARISON")
print("="*70)

print(f"""
Study                | Population       | N     | Spearman ρ | Pearson r | AUROC (Q5) | AUROC (Q4)
---------------------|------------------|-------|------------|-----------|------------|------------
Segar 2024           | Trial derivation | 707   | —          | —         | 0.84       | —
Mauch 2025           | Floor patients   | 317   | —          | -0.40     | —          | —
Pandey 2025          | CLOROTIC trial   | 220   | —          | —         | —          | 0.70
---------------------|------------------|-------|------------|-----------|------------|------------
THIS STUDY           | ICU patients     | {len(df_24h):,}   | {rho:.3f}      | {pearson_r:.3f}     | {auroc_quintile:.3f}      | {auroc_quartile:.3f}

Key Findings:
- Spearman ρ = {rho:.3f} (not reported by the comparator studies; Pearson r is the cross-study comparison)
- Pearson r = {pearson_r:.3f}, {abs(pct_diff):.0f}% {'stronger' if abs(pearson_r) > abs(mauch_r) else 'weaker'} than Mauch's r = -0.40
- AUROC (quintile) = {auroc_quintile:.3f}, {'comparable to' if abs(segar_diff) < 0.05 else ('higher than' if segar_diff > 0 else 'lower than')} Segar's 0.84 ({segar_diff:+.3f})
- AUROC (quartile) = {auroc_quartile:.3f}, {'higher than' if pandey_diff > 0 else 'lower than or equal to'} Pandey's 0.70 ({pandey_diff:+.3f})
""")

#------------------------------------------------------------------------------
# SUMMARY TABLE
#------------------------------------------------------------------------------
print("="*70)
print("SUMMARY: 24-HOUR DIURETIC EFFICIENCY DISCRIMINATION")
print("="*70)

print(f"""
Metric                              Value           95% CI                  Reference
---------------------------------------------------------------------------------------------
Spearman ρ (primary)                {rho:.3f}          ({rho_ci_lower:.3f} to {rho_ci_upper:.3f})       —
Pearson r (Mauch comparison)        {pearson_r:.3f}          ({pearson_ci_lower:.3f} to {pearson_ci_upper:.3f})       Mauch: -0.40
C-index                             {c_index:.3f}          ({c_index_ci_lower:.3f} to {c_index_ci_upper:.3f})       —
AUROC (lowest quintile)             {auroc_quintile:.3f}          ({auroc_q_ci_lower:.3f} to {auroc_q_ci_upper:.3f})       Segar: 0.84
AUROC (lowest quartile)             {auroc_quartile:.3f}          ({auroc_qt_ci_lower:.3f} to {auroc_qt_ci_upper:.3f})       Pandey: 0.70
""")

#------------------------------------------------------------------------------
# STORE RESULTS FOR USE IN SUBSEQUENT CELLS
#------------------------------------------------------------------------------

results_24h = {
    'n': len(df_24h),
    'spearman_rho': rho,
    'spearman_ci': (rho_ci_lower, rho_ci_upper),
    'pearson_r': pearson_r,
    'pearson_ci': (pearson_ci_lower, pearson_ci_upper),
    'c_index': c_index,
    'c_index_ci': (c_index_ci_lower, c_index_ci_upper),
    'auroc_quintile': auroc_quintile,
    'auroc_quintile_ci': (auroc_q_ci_lower, auroc_q_ci_upper),
    'auroc_quartile': auroc_quartile,
    'auroc_quartile_ci': (auroc_qt_ci_lower, auroc_qt_ci_upper),
    'quintile_threshold': quintile_threshold,
    'quartile_threshold': quartile_threshold,
    'median_by_risk': {
        'Low': df_24h[df_24h['risk_category']=='Low']['diuretic_efficiency_24h'].median(),
        'Moderate': df_24h[df_24h['risk_category']=='Moderate']['diuretic_efficiency_24h'].median(),
        'High': df_24h[df_24h['risk_category']=='High']['diuretic_efficiency_24h'].median()
    }
}

print("\n✓ 24h efficiency results stored in 'results_24h' dictionary")
print("→ Next: Co-Primary Outcome 2 - 72h Diuretic Efficiency")
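
Every discrimination metric in these cells repeats the same percentile-bootstrap loop: resample rows with replacement, recompute the statistic, and take the 2.5th and 97.5th percentiles. A minimal sketch of a reusable helper is below. `bootstrap_ci` is a hypothetical name not defined elsewhere in this notebook, and it draws from NumPy's `default_rng` rather than the global `np.random.seed` state, so its intervals will differ slightly from the ones printed by the loops above.

```python
import numpy as np


def bootstrap_ci(stat_fn, *arrays, n_bootstrap=1000, alpha=0.05, seed=42):
    """Percentile bootstrap CI for stat_fn(*arrays).

    All arrays are resampled with the same row indices, so paired values
    (score, outcome) stay together. stat_fn may return None, e.g. for an
    AUROC on a single-class resample; those replicates are skipped.
    Returns (lower, upper, n_valid_replicates).
    """
    arrays = [np.asarray(a) for a in arrays]
    n = len(arrays[0])
    rng = np.random.default_rng(seed)
    estimates = []
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)  # sample row indices with replacement
        value = stat_fn(*(a[idx] for a in arrays))
        if value is not None:
            estimates.append(value)
    lower = np.percentile(estimates, 100 * alpha / 2)
    upper = np.percentile(estimates, 100 * (1 - alpha / 2))
    return float(lower), float(upper), len(estimates)
```

For example, the Pearson interval could be written as `bootstrap_ci(lambda s, e: pearsonr(s, e)[0], df_24h['ban_adhf_total_score'], df_24h['diuretic_efficiency_24h'])`. Passing arrays instead of the DataFrame avoids an `.iloc` copy on every replicate, and returning the number of valid replicates makes skipped single-class resamples visible.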

In [None]:
#==========================================================================
# SECTION 5: CO-PRIMARY OUTCOME 2 - 72-HOUR DIURETIC EFFICIENCY
# Cell 10: Correlation, C-index, and AUROC Analysis
#==========================================================================
# PURPOSE: Evaluate BAN-ADHF score's ability to predict 72h diuretic efficiency
# POPULATION: N=781 (ICU ≥72h with IV diuretic dose >0)
# NOTE: This matches the ORIGINAL BAN-ADHF derivation endpoint (Segar 2024)
# METRICS: Spearman correlation (primary), Pearson (for comparison), C-index,
#          AUROC for lowest quintile/quartile, efficiency by risk category
# OUTPUT: Co-primary discrimination statistics with 95% CIs
#==========================================================================

from scipy.stats import spearmanr, pearsonr, kruskal
from sklearn.metrics import roc_auc_score
from lifelines.utils import concordance_index
import numpy as np

print("="*70)
print("CO-PRIMARY OUTCOME 2: 72-HOUR DIURETIC EFFICIENCY")
print("="*70)
print("\nNOTE: 72-hour diuretic efficiency was the ORIGINAL endpoint used")
print("in the BAN-ADHF derivation study (Segar 2024).")

#------------------------------------------------------------------------------
# ANALYSIS COHORT
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("ANALYSIS COHORT")
print("-"*70)

df_72h = df[(df['icu_stay_ge_72h'] == 1) &
            (df['diuretic_efficiency_72h'].notna()) &
            (df['diuretic_efficiency_72h'] > 0)].copy()

print(f"Population: ICU ≥72h with IV diuretic dose >0 in first 72 hours")
print(f"N = {len(df_72h):,}")
print(f"\nDiuretic Efficiency Distribution:")
print(f"  Mean ± SD:    {df_72h['diuretic_efficiency_72h'].mean():.1f} ± {df_72h['diuretic_efficiency_72h'].std():.1f} mL/mg")
print(f"  Median (IQR): {df_72h['diuretic_efficiency_72h'].median():.1f} ({df_72h['diuretic_efficiency_72h'].quantile(0.25):.1f}-{df_72h['diuretic_efficiency_72h'].quantile(0.75):.1f}) mL/mg")
print(f"  Range:        {df_72h['diuretic_efficiency_72h'].min():.1f} - {df_72h['diuretic_efficiency_72h'].max():.1f} mL/mg")

skewness = df_72h['diuretic_efficiency_72h'].skew()
skew_label = 'right-skewed' if skewness > 0.5 else ('left-skewed' if skewness < -0.5 else 'approximately symmetric')
print(f"  Skewness:     {skewness:.2f} ({skew_label})")

#------------------------------------------------------------------------------
# BOOTSTRAP SETUP
#------------------------------------------------------------------------------
np.random.seed(42)
n_bootstrap = 1000

#------------------------------------------------------------------------------
# 1. SPEARMAN CORRELATION (PRIMARY)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("1. SPEARMAN CORRELATION (PRIMARY METRIC)")
print("="*70)

rho_72, rho_72_p = spearmanr(df_72h['ban_adhf_total_score'],
                              df_72h['diuretic_efficiency_72h'])

# Bootstrap 95% CI
rho_72_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_72h), size=len(df_72h), replace=True)
    sample = df_72h.iloc[idx]
    r, _ = spearmanr(sample['ban_adhf_total_score'], sample['diuretic_efficiency_72h'])
    rho_72_bootstrap.append(r)

rho_72_ci_lower = np.percentile(rho_72_bootstrap, 2.5)
rho_72_ci_upper = np.percentile(rho_72_bootstrap, 97.5)

print(f"\nSpearman ρ = {rho_72:.3f} (95% CI: {rho_72_ci_lower:.3f} to {rho_72_ci_upper:.3f})")
print(f"P-value: < 0.001" if rho_72_p < 0.001 else f"P-value: {rho_72_p:.4f}")

print(f"\nInterpretation:")
direction = 'inverse' if rho_72 < 0 else 'positive'
if abs(rho_72) > 0.5:
    print(f"  Strong {direction} correlation (|ρ| > 0.5)")
elif abs(rho_72) > 0.3:
    print(f"  Moderate {direction} correlation (|ρ| 0.3-0.5)")
else:
    print(f"  Weak {direction} correlation (|ρ| ≤ 0.3)")

#------------------------------------------------------------------------------
# 1b. PEARSON CORRELATION (FOR COMPARISON)
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("1b. PEARSON CORRELATION")
print("-"*70)

pearson_72, pearson_72_p = pearsonr(df_72h['ban_adhf_total_score'],
                                     df_72h['diuretic_efficiency_72h'])

# Bootstrap CI
pearson_72_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_72h), size=len(df_72h), replace=True)
    sample = df_72h.iloc[idx]
    r, _ = pearsonr(sample['ban_adhf_total_score'], sample['diuretic_efficiency_72h'])
    pearson_72_bootstrap.append(r)

pearson_72_ci_lower = np.percentile(pearson_72_bootstrap, 2.5)
pearson_72_ci_upper = np.percentile(pearson_72_bootstrap, 97.5)

print(f"\nPearson r = {pearson_72:.3f} (95% CI: {pearson_72_ci_lower:.3f} to {pearson_72_ci_upper:.3f})")

#------------------------------------------------------------------------------
# 2. C-INDEX (HARRELL'S CONCORDANCE INDEX)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("2. C-INDEX (HARRELL'S CONCORDANCE INDEX)")
print("="*70)

c_index_72 = concordance_index(
    df_72h['diuretic_efficiency_72h'],
    -df_72h['ban_adhf_total_score'],
    np.ones(len(df_72h))
)

# Bootstrap 95% CI
c_index_72_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_72h), size=len(df_72h), replace=True)
    sample = df_72h.iloc[idx]
    c = concordance_index(
        sample['diuretic_efficiency_72h'],
        -sample['ban_adhf_total_score'],
        np.ones(len(sample))
    )
    c_index_72_bootstrap.append(c)

c_index_72_ci_lower = np.percentile(c_index_72_bootstrap, 2.5)
c_index_72_ci_upper = np.percentile(c_index_72_bootstrap, 97.5)

print(f"\nC-index = {c_index_72:.3f} (95% CI: {c_index_72_ci_lower:.3f} to {c_index_72_ci_upper:.3f})")

print(f"\nInterpretation:")
if c_index_72 >= 0.7:
    print(f"  Good discrimination (C-index ≥ 0.7)")
elif c_index_72 >= 0.6:
    print(f"  Moderate discrimination (C-index 0.6-0.7)")
else:
    print(f"  Poor discrimination (C-index < 0.6)")

#------------------------------------------------------------------------------
# 3. AUROC FOR LOWEST QUINTILE (Bottom 20%)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("3. AUROC FOR LOWEST QUINTILE OF 72H EFFICIENCY")
print("="*70)

quintile_72_threshold = df_72h['diuretic_efficiency_72h'].quantile(0.20)
df_72h['lowest_quintile_72'] = (df_72h['diuretic_efficiency_72h'] <= quintile_72_threshold).astype(int)

n_quintile_72 = df_72h['lowest_quintile_72'].sum()
print(f"Threshold: ≤{quintile_72_threshold:.1f} mL/mg (20th percentile)")
print(f"Patients in lowest quintile: {n_quintile_72} ({100*n_quintile_72/len(df_72h):.1f}%)")

auroc_quintile_72 = roc_auc_score(df_72h['lowest_quintile_72'], df_72h['ban_adhf_total_score'])

# Bootstrap 95% CI
auroc_q72_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_72h), size=len(df_72h), replace=True)
    sample = df_72h.iloc[idx]
    if sample['lowest_quintile_72'].sum() > 0 and sample['lowest_quintile_72'].sum() < len(sample):
        auc = roc_auc_score(sample['lowest_quintile_72'], sample['ban_adhf_total_score'])
        auroc_q72_bootstrap.append(auc)

auroc_q72_ci_lower = np.percentile(auroc_q72_bootstrap, 2.5)
auroc_q72_ci_upper = np.percentile(auroc_q72_bootstrap, 97.5)

print(f"\nAUROC = {auroc_quintile_72:.3f} (95% CI: {auroc_q72_ci_lower:.3f} to {auroc_q72_ci_upper:.3f})")

#------------------------------------------------------------------------------
# 4. AUROC FOR LOWEST QUARTILE (Bottom 25%)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("4. AUROC FOR LOWEST QUARTILE OF 72H EFFICIENCY")
print("="*70)

quartile_72_threshold = df_72h['diuretic_efficiency_72h'].quantile(0.25)
df_72h['lowest_quartile_72'] = (df_72h['diuretic_efficiency_72h'] <= quartile_72_threshold).astype(int)

n_quartile_72 = df_72h['lowest_quartile_72'].sum()
print(f"Threshold: ≤{quartile_72_threshold:.1f} mL/mg (25th percentile)")
print(f"Patients in lowest quartile: {n_quartile_72} ({100*n_quartile_72/len(df_72h):.1f}%)")

auroc_quartile_72 = roc_auc_score(df_72h['lowest_quartile_72'], df_72h['ban_adhf_total_score'])

# Bootstrap 95% CI
auroc_qt72_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_72h), size=len(df_72h), replace=True)
    sample = df_72h.iloc[idx]
    if sample['lowest_quartile_72'].sum() > 0 and sample['lowest_quartile_72'].sum() < len(sample):
        auc = roc_auc_score(sample['lowest_quartile_72'], sample['ban_adhf_total_score'])
        auroc_qt72_bootstrap.append(auc)

auroc_qt72_ci_lower = np.percentile(auroc_qt72_bootstrap, 2.5)
auroc_qt72_ci_upper = np.percentile(auroc_qt72_bootstrap, 97.5)

print(f"\nAUROC = {auroc_quartile_72:.3f} (95% CI: {auroc_qt72_ci_lower:.3f} to {auroc_qt72_ci_upper:.3f})")

#------------------------------------------------------------------------------
# 5. EFFICIENCY BY RISK CATEGORY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("5. 72H DIURETIC EFFICIENCY BY RISK CATEGORY")
print("="*70)

print(f"\n{'Risk Category':<15} {'N':<8} {'Median (IQR)':<25} {'Mean ± SD':<20}")
print("-"*70)

median_72_by_risk = {}
for cat in ['Low', 'Moderate', 'High']:
    subset = df_72h[df_72h['risk_category'] == cat]['diuretic_efficiency_72h']
    n = len(subset)
    median = subset.median()
    q1 = subset.quantile(0.25)
    q3 = subset.quantile(0.75)
    mean = subset.mean()
    std = subset.std()
    median_72_by_risk[cat] = median
    iqr_str = f"{median:.1f} ({q1:.1f}-{q3:.1f})"
    print(f"{cat:<15} {n:<8} {iqr_str:<25} {mean:.1f} ± {std:.1f}")

# Kruskal-Wallis test
groups_72 = [df_72h[df_72h['risk_category'] == cat]['diuretic_efficiency_72h'].values
             for cat in ['Low', 'Moderate', 'High']]
kw_stat_72, kw_p_72 = kruskal(*groups_72)
print(f"\nKruskal-Wallis H = {kw_stat_72:.1f}, p < 0.001" if kw_p_72 < 0.001 else f"\nKruskal-Wallis H = {kw_stat_72:.1f}, p = {kw_p_72:.3f}")

#------------------------------------------------------------------------------
# COMPARISON: 24H vs 72H EFFICIENCY (Using stored results)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("COMPARISON: 24-HOUR vs 72-HOUR DIURETIC EFFICIENCY")
print("="*70)

# Check if results_24h exists, otherwise use placeholders
if 'results_24h' in dir():
    r24 = results_24h
    print(f"""
Metric                      24-Hour (N={r24['n']:,})       72-Hour (N={len(df_72h):,})
--------------------------------------------------------------------------------
Spearman ρ                  {r24['spearman_rho']:.3f}                  {rho_72:.3f}
Pearson r                   {r24['pearson_r']:.3f}                  {pearson_72:.3f}
C-index                     {r24['c_index']:.3f}                  {c_index_72:.3f}
AUROC (quintile)            {r24['auroc_quintile']:.3f}                  {auroc_quintile_72:.3f}
AUROC (quartile)            {r24['auroc_quartile']:.3f}                  {auroc_quartile_72:.3f}

Efficiency by Risk:
  Low risk median           {r24['median_by_risk']['Low']:.1f} mL/mg              {median_72_by_risk['Low']:.1f} mL/mg
  Moderate median           {r24['median_by_risk']['Moderate']:.1f} mL/mg              {median_72_by_risk['Moderate']:.1f} mL/mg
  High risk median          {r24['median_by_risk']['High']:.1f} mL/mg              {median_72_by_risk['High']:.1f} mL/mg
""")
else:
    print("\n⚠ results_24h not found: run the Cell 9 storage snippet first")
    print("Only the 72h results are shown in the summary below.")

#------------------------------------------------------------------------------
# STORE 72H RESULTS
#------------------------------------------------------------------------------
results_72h = {
    'n': len(df_72h),
    'spearman_rho': rho_72,
    'spearman_ci': (rho_72_ci_lower, rho_72_ci_upper),
    'pearson_r': pearson_72,
    'pearson_ci': (pearson_72_ci_lower, pearson_72_ci_upper),
    'c_index': c_index_72,
    'c_index_ci': (c_index_72_ci_lower, c_index_72_ci_upper),
    'auroc_quintile': auroc_quintile_72,
    'auroc_quintile_ci': (auroc_q72_ci_lower, auroc_q72_ci_upper),
    'auroc_quartile': auroc_quartile_72,
    'auroc_quartile_ci': (auroc_qt72_ci_lower, auroc_qt72_ci_upper),
    'quintile_threshold': quintile_72_threshold,
    'quartile_threshold': quartile_72_threshold,
    'median_by_risk': median_72_by_risk
}

#------------------------------------------------------------------------------
# SUMMARY TABLE
#------------------------------------------------------------------------------
print("="*70)
print("SUMMARY: 72-HOUR DIURETIC EFFICIENCY DISCRIMINATION")
print("="*70)

print(f"""
Metric                              Value           95% CI
--------------------------------------------------------------------------------
Spearman ρ                          {rho_72:.3f}          ({rho_72_ci_lower:.3f} to {rho_72_ci_upper:.3f})
Pearson r                           {pearson_72:.3f}          ({pearson_72_ci_lower:.3f} to {pearson_72_ci_upper:.3f})
C-index                             {c_index_72:.3f}          ({c_index_72_ci_lower:.3f} to {c_index_72_ci_upper:.3f})
AUROC (lowest quintile)             {auroc_quintile_72:.3f}          ({auroc_q72_ci_lower:.3f} to {auroc_q72_ci_upper:.3f})
AUROC (lowest quartile)             {auroc_quartile_72:.3f}          ({auroc_qt72_ci_lower:.3f} to {auroc_qt72_ci_upper:.3f})
""")

# Determine comparison
if 'results_24h' in dir():
    rho_diff = rho_72 - r24['spearman_rho']
    comparison = 'comparable' if abs(rho_diff) < 0.03 else ('weaker' if rho_diff > 0 else 'stronger')
    print(f"Key Finding: 72-hour efficiency (original derivation endpoint) shows")
    print(f"{comparison} discrimination compared to 24-hour efficiency (Δρ = {rho_diff:+.3f})")

print("\n✓ 72h efficiency results stored in 'results_72h' dictionary")
print("\n→ Next: Diuretic Resistance Analysis (Binary Outcome)")

In [None]:
#==========================================================================
# SECTION 6: DIURETIC RESISTANCE ANALYSIS
# Cell 11: Binary Diuretic Resistance (Urine Output ≤3,000 mL)
#==========================================================================
# PURPOSE: Evaluate BAN-ADHF score's ability to predict binary DR
# POPULATION: N=1,382 (ICU ≥24h)
# DEFINITION: Urine output ≤3,000 mL in first 24 hours
# METRICS: AUROC, sensitivity/specificity, DR rates by risk category
# OUTPUT: Binary outcome discrimination statistics
#==========================================================================

from scipy.stats import chi2_contingency
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix
import numpy as np
import pandas as pd  # used below for pd.to_numeric and pd.crosstab

print("="*70)
print("DIURETIC RESISTANCE ANALYSIS (BINARY OUTCOME)")
print("="*70)
print("\nDefinition: Urine output ≤3,000 mL in first 24 hours of ICU stay")
print("This is a SECONDARY outcome (BAN-ADHF was designed for efficiency,")
print("not binary resistance).")

#------------------------------------------------------------------------------
# ANALYSIS COHORT
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("ANALYSIS COHORT")
print("-"*70)

# ICU ≥24h cohort for DR analysis
df_dr = df[df['icu_stay_ge_24h'] == 1].copy()

print(f"Population: ICU stay ≥24 hours")
print(f"N = {len(df_dr):,}")

# Verify DR definition
df_dr['diuretic_resistance'] = pd.to_numeric(df_dr['diuretic_resistance'], errors='coerce')
n_dr = df_dr['diuretic_resistance'].sum()
dr_prevalence = 100 * n_dr / len(df_dr)

print(f"\nDiuretic Resistance:")
print(f"  Resistant (≤3L):     {int(n_dr):,} ({dr_prevalence:.1f}%)")
print(f"  Not resistant (>3L): {len(df_dr) - int(n_dr):,} ({100-dr_prevalence:.1f}%)")

print(f"\nNote: DR prevalence ({dr_prevalence:.1f}%) is higher than in the derivation")
print(f"cohorts (~25%), consistent with the greater illness severity of ICU patients.")

#------------------------------------------------------------------------------
# BOOTSTRAP SETUP
#------------------------------------------------------------------------------
np.random.seed(42)
n_bootstrap = 1000

#------------------------------------------------------------------------------
# 1. AUROC FOR DIURETIC RESISTANCE
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("1. AUROC FOR DIURETIC RESISTANCE PREDICTION")
print("="*70)

auroc_dr = roc_auc_score(df_dr['diuretic_resistance'], df_dr['ban_adhf_total_score'])

# Bootstrap 95% CI
auroc_dr_bootstrap = []
for _ in range(n_bootstrap):
    idx = np.random.choice(len(df_dr), size=len(df_dr), replace=True)
    sample = df_dr.iloc[idx]
    if sample['diuretic_resistance'].sum() > 0 and sample['diuretic_resistance'].sum() < len(sample):
        auc = roc_auc_score(sample['diuretic_resistance'], sample['ban_adhf_total_score'])
        auroc_dr_bootstrap.append(auc)

auroc_dr_ci_lower = np.percentile(auroc_dr_bootstrap, 2.5)
auroc_dr_ci_upper = np.percentile(auroc_dr_bootstrap, 97.5)

print(f"\nAUROC = {auroc_dr:.3f} (95% CI: {auroc_dr_ci_lower:.3f} to {auroc_dr_ci_upper:.3f})")

print(f"\nInterpretation:")
if auroc_dr >= 0.7:
    print(f"  Good discrimination (AUROC ≥ 0.7)")
elif auroc_dr >= 0.6:
    print(f"  Moderate discrimination (AUROC 0.6-0.7)")
else:
    print(f"  Poor discrimination (AUROC < 0.6)")

#------------------------------------------------------------------------------
# 2. OPTIMAL CUTOFF (YOUDEN'S INDEX)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("2. OPTIMAL CUTOFF FOR DIURETIC RESISTANCE (YOUDEN'S INDEX)")
print("="*70)

fpr, tpr, thresholds = roc_curve(df_dr['diuretic_resistance'],
                                  df_dr['ban_adhf_total_score'])

# Youden's J = Sensitivity + Specificity - 1
youden_j = tpr - fpr
optimal_idx = np.argmax(youden_j)
optimal_cutoff = thresholds[optimal_idx]
optimal_sensitivity = tpr[optimal_idx]
optimal_specificity = 1 - fpr[optimal_idx]

print(f"\nOptimal cutoff: ≥{optimal_cutoff:.0f}")
print(f"Sensitivity: {optimal_sensitivity:.3f} ({optimal_sensitivity*100:.1f}%)")
print(f"Specificity: {optimal_specificity:.3f} ({optimal_specificity*100:.1f}%)")
print(f"Youden's J: {youden_j[optimal_idx]:.3f}")

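# Compare to our high-risk threshold (≥13). Compute directly rather than taking
# the nearest roc_curve threshold, which is not 13 when no patient scored exactly 13.
is_dr = (df_dr['diuretic_resistance'] == 1).values
pred_13 = (df_dr['ban_adhf_total_score'] >= 13).values
sens_13 = pred_13[is_dr].mean()
spec_13 = (~pred_13[~is_dr]).mean()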
print(f"\nAt our high-risk threshold (≥13):")
print(f"  Sensitivity: {sens_13:.3f} ({sens_13*100:.1f}%)")
print(f"  Specificity: {spec_13:.3f} ({spec_13*100:.1f}%)")
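# --- Sketch: sensitivity/specificity at every observed cutoff ---
# Not part of the original pipeline. `cutoff_table` and `youden_best` are
# hypothetical helpers that apply the rule "score >= k" directly for each
# observed score k. Because roc_curve also classifies as positive when
# score >= threshold, the optimum found this way should have the same
# Youden's J as the roc_curve-based optimum above (the reported cutoff can
# differ when several cutoffs tie). Assumes y has no missing values.
import numpy as np

def cutoff_table(y, score):
    y = np.asarray(y).astype(bool)
    s = np.asarray(score, dtype=float)
    rows = []
    for k in np.unique(s):
        pred = s >= k
        sens = pred[y].mean()          # true positives / all DR patients
        spec = (~pred[~y]).mean()      # true negatives / all non-DR patients
        rows.append((float(k), float(sens), float(spec), float(sens + spec - 1)))
    return rows

def youden_best(y, score):
    # Returns (cutoff, sensitivity, specificity, J); the lowest cutoff wins ties
    return max(cutoff_table(y, score), key=lambda row: row[3])

# Example check (uncomment to run):
# print(youden_best(df_dr['diuretic_resistance'], df_dr['ban_adhf_total_score']))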

#------------------------------------------------------------------------------
# 3. DIURETIC RESISTANCE BY RISK CATEGORY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("3. DIURETIC RESISTANCE BY RISK CATEGORY")
print("="*70)

print(f"\n{'Risk Category':<15} {'N':<8} {'DR n (%)':<20} {'No DR n (%)':<20}")
print("-"*65)

dr_by_risk = {}
for cat in ['Low', 'Moderate', 'High']:
    subset = df_dr[df_dr['risk_category'] == cat]
    n = len(subset)
    n_dr_cat = int(subset['diuretic_resistance'].sum())
    dr_rate = 100 * n_dr_cat / n if n > 0 else 0
    dr_by_risk[cat] = {'n': n, 'n_dr': n_dr_cat, 'rate': dr_rate}
    dr_str = f"{n_dr_cat} ({dr_rate:.1f}%)"
    no_dr_str = f"{n - n_dr_cat} ({100-dr_rate:.1f}%)"
    print(f"{cat:<15} {n:<8} {dr_str:<20} {no_dr_str:<20}")

# Chi-square test
contingency = pd.crosstab(df_dr['risk_category'], df_dr['diuretic_resistance'])
chi2, p_chi, dof, expected = chi2_contingency(contingency)
print(f"\nChi-square = {chi2:.1f}, p < 0.001" if p_chi < 0.001 else f"\nChi-square = {chi2:.1f}, p = {p_chi:.3f}")

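# Risk ratio (High vs Low); stays NaN if the low-risk group has no DR events,
# so the summary below never references an undefined variable
low_rate = dr_by_risk['Low']['rate']
high_rate = dr_by_risk['High']['rate']
rr_high_vs_low = float('nan')
if low_rate > 0:
    rr_high_vs_low = high_rate / low_rate
    print(f"\nRisk ratio (High vs Low): {rr_high_vs_low:.2f}")
    print(f"  High-risk patients are {rr_high_vs_low:.1f} times as likely as low-risk patients to have DR")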

#------------------------------------------------------------------------------
# 4. SENSITIVITY ANALYSIS: DR THRESHOLD VARIATIONS
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("4. SENSITIVITY ANALYSIS: ALTERNATIVE DR DEFINITIONS")
print("="*70)

# Test different urine output thresholds
print("\nAUROC by different DR threshold definitions:")
print(f"{'Threshold':<20} {'N with DR':<15} {'Prevalence':<15} {'AUROC':<15}")
print("-"*65)

for threshold in [2000, 2500, 3000, 3500, 4000]:
    dr_alt = (df_dr['urine_output_24h_ml'] <= threshold).astype(int)
    n_dr_alt = dr_alt.sum()
    prev_alt = 100 * n_dr_alt / len(df_dr)
    if n_dr_alt > 0 and n_dr_alt < len(df_dr):
        auroc_alt = roc_auc_score(dr_alt, df_dr['ban_adhf_total_score'])
        print(f"≤{threshold} mL{'':<12} {n_dr_alt:<15} {prev_alt:.1f}%{'':>8} {auroc_alt:.3f}")

#------------------------------------------------------------------------------
# 5. COMPARISON WITH EFFICIENCY METRICS
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("5. COMPARISON: BINARY DR vs CONTINUOUS EFFICIENCY")
print("="*70)

if 'results_24h' in dir():
    print(f"""
Outcome                         AUROC           Notes
--------------------------------------------------------------------------------
24h Efficiency (quintile)       {results_24h['auroc_quintile']:.3f}          Bottom 20% of efficiency
24h Efficiency (quartile)       {results_24h['auroc_quartile']:.3f}          Bottom 25% of efficiency
Binary DR (≤3L)                 {auroc_dr:.3f}          Urine ≤3,000 mL

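Key Observation:
Binary DR shows {'comparable' if abs(auroc_dr - results_24h['auroc_quintile']) < 0.05 else ('higher' if auroc_dr > results_24h['auroc_quintile'] else 'lower')}
discrimination compared to the efficiency quintile/quartile outcomes.
All three outcomes are dichotomized; unlike efficiency, the ≤3 L urine threshold
is not normalized to diuretic dose, the quantity BAN-ADHF was designed to predict.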
""")
else:
    print(f"\nBinary DR AUROC: {auroc_dr:.3f}")

#------------------------------------------------------------------------------
# STORE RESULTS
#------------------------------------------------------------------------------
results_dr = {
    'n': len(df_dr),
    'n_dr': int(n_dr),
    'prevalence': dr_prevalence,
    'auroc': auroc_dr,
    'auroc_ci': (auroc_dr_ci_lower, auroc_dr_ci_upper),
    'optimal_cutoff': optimal_cutoff,
    'optimal_sensitivity': optimal_sensitivity,
    'optimal_specificity': optimal_specificity,
    'dr_by_risk': dr_by_risk
}

#------------------------------------------------------------------------------
# SUMMARY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SUMMARY: DIURETIC RESISTANCE DISCRIMINATION")
print("="*70)

print(f"""
Population: ICU ≥24h (N = {len(df_dr):,})
DR Definition: Urine output ≤3,000 mL in first 24 hours
DR Prevalence: {dr_prevalence:.1f}% (higher than in the derivation cohorts, consistent with ICU severity)

Discrimination:
  AUROC = {auroc_dr:.3f} (95% CI: {auroc_dr_ci_lower:.3f} to {auroc_dr_ci_upper:.3f})

Optimal Cutoff (Youden):
  Score ≥{optimal_cutoff:.0f}
  Sensitivity: {optimal_sensitivity*100:.1f}%
  Specificity: {optimal_specificity*100:.1f}%

DR Rates by Risk Category:
  Low risk (≤7):      {dr_by_risk['Low']['rate']:.1f}%
  Moderate (8-12):    {dr_by_risk['Moderate']['rate']:.1f}%
  High risk (≥13):    {dr_by_risk['High']['rate']:.1f}%

Risk Ratio (High vs Low): {rr_high_vs_low:.2f}x
""")

print("✓ DR results stored in 'results_dr' dictionary")
print("\n→ Next: Co-Primary Outcome 3 - In-Hospital Mortality (Exploratory)")

In [None]:
# Snapshot the 24h efficiency AUROC results as plain floats under a separate name,
# so later cells that reuse variables such as auroc_quintile cannot change them
results_24h_auc = {
    "n": int(len(df_24h)),
    "auroc_q20": float(auroc_quintile),
    "auroc_q25": float(auroc_quartile),
    "auroc_q20_ci": (float(auroc_q_ci_lower), float(auroc_q_ci_upper)),
    "auroc_q25_ci": (float(auroc_qt_ci_lower), float(auroc_qt_ci_upper)),
}

print("✓ results_24h_auc created")
print(results_24h_auc)

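The Q20/Q25 outcomes flag patients whose efficiency is at or below a sample quantile (`<=`), and the comparison table later lists their prevalence as a nominal 20.0% / 25.0%. If efficiency values tie at the cutoff, the realized share of flagged patients can exceed the nominal fraction, so it is worth checking. A minimal sketch with a hypothetical helper, `lowest_fraction_flag`, on synthetic values:

```python
import pandas as pd


def lowest_fraction_flag(values, q=0.20):
    """Flag observations at or below the q-th sample quantile.

    Returns (threshold, 0/1 flag Series, realized prevalence). With ties at the
    threshold, '<=' can flag more than the nominal fraction q.
    """
    s = pd.Series(values, dtype=float)
    threshold = s.quantile(q)           # pandas default: linear interpolation
    flag = (s <= threshold).astype(int)
    return threshold, flag, flag.mean()


# Synthetic efficiencies (mL/mg) with four tied values at the bottom
eff = [5, 5, 5, 5, 20, 30, 40, 50, 60, 70]
thr, flag, prev = lowest_fraction_flag(eff, 0.20)
print(thr, prev)  # 5.0 0.4  -> 40% flagged, not 20%
```

Reporting the realized prevalence (for example `df_sub['lowest_quintile'].mean()`) avoids relying on the nominal value.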

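The mortality cell below reports AUPRC alongside AUROC. An uninformative score has an AUROC of 0.5, but its AUPRC equals the outcome prevalence, so AUPRC is easiest to read against the in-hospital mortality rate. The sketch uses synthetic labels and a hypothetical helper, `auprc_with_baseline`, with scikit-learn's `average_precision_score` (the same function the notebook uses) to show both extremes:

```python
import numpy as np
from sklearn.metrics import average_precision_score


def auprc_with_baseline(y_true, score):
    """Return (AUPRC, prevalence, AUPRC / prevalence) for a binary outcome."""
    y = np.asarray(y_true)
    prevalence = float(y.mean())
    ap = float(average_precision_score(y, score))
    return ap, prevalence, ap / prevalence


# Synthetic labels with a 20% event rate (not study data)
y = np.array([1, 0, 0, 0, 1, 0, 0, 0, 0, 0])
uninformative = np.zeros(len(y))   # identical score for everyone
perfect = y.astype(float)          # every event outranks every non-event

print(auprc_with_baseline(y, uninformative))  # AUPRC equals prevalence, ratio 1
print(auprc_with_baseline(y, perfect))        # AUPRC 1.0, ratio 5
```

A ratio above 1 indicates the score ranks deaths above survivors better than chance at the observed event rate.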
In [None]:
#==========================================================================
# SECTION 7: CO-PRIMARY OUTCOME 3 - IN-HOSPITAL MORTALITY (EXPLORATORY)
# Cell 12: Mortality Discrimination Analysis
#==========================================================================
# PURPOSE: Exploratory assessment of BAN-ADHF score association with in-hospital
#          mortality (note: score developed for diuretic efficiency, not mortality)
# POPULATION: N = 1,505 (Full cohort)
# IMPORTANT: This mortality analysis is EXPLORATORY and represents an extension
#            of the score to a different outcome domain.
# METRICS:
#   - Discrimination: AUROC with 95% CI by bootstrap resampling (1,000 iterations)
#   - Secondary discrimination: AUPRC with 95% CI by bootstrap resampling
#   - Exploratory cutoff: Youden's index and comparison to prespecified threshold (≥13)
#   - Risk stratification: Mortality rates by predefined risk category + chi-square test
#   - Subgroups: AUROC (bootstrap CI) in cardiogenic shock vs no cardiogenic shock
#   - Comparison table: uses results_24h_auc (Q20/Q25) + results_dr (binary DR) if present
# OUTPUT: Exploratory mortality discrimination statistics
#==========================================================================

from scipy.stats import chi2_contingency
from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score
import numpy as np
import pandas as pd

print("="*70)
print("CO-PRIMARY OUTCOME 3: IN-HOSPITAL MORTALITY (EXPLORATORY)")
print("="*70)
print("\n" + "*"*70)
print("IMPORTANT: BAN-ADHF was designed for DIURETIC EFFICIENCY prediction,")
print("NOT mortality. This mortality analysis is EXPLORATORY and represents")
print("a novel extension of the score to a different outcome domain.")
print("*"*70)

#------------------------------------------------------------------------------
# ANALYSIS COHORT
#------------------------------------------------------------------------------
print("\n" + "-"*70)
print("ANALYSIS COHORT")
print("-"*70)

df_mort = df.copy()

n_total = len(df_mort)
n_deaths = int(df_mort['hospital_expire_flag'].sum())
mortality_rate = 100 * n_deaths / n_total

print("Population: Full ICU ADHF cohort")
print(f"N = {n_total:,}")
print("\nIn-Hospital Mortality:")
print(f"  Deaths:    {n_deaths} ({mortality_rate:.1f}%)")
print(f"  Survivors: {n_total - n_deaths} ({100-mortality_rate:.1f}%)")

#------------------------------------------------------------------------------
# BOOTSTRAP SETUP
#------------------------------------------------------------------------------
np.random.seed(42)
n_bootstrap = 1000

def bootstrap_metric_ci(df_in, y_col, s_col, metric_fn, n_boot=1000):
    """
    Percentile-bootstrap 95% CI for a metric that requires both outcome classes.
    Resamples that draw only one class are skipped.
    Returns: (point_estimate, ci_low, ci_high, n_valid_boot)
    """
    y = df_in[y_col].values
    s = df_in[s_col].values
    n = len(y)

    point = metric_fn(y, s)

    boot_vals = []
    for _ in range(n_boot):
        # Index the arrays directly instead of copying a DataFrame each iteration
        idx = np.random.choice(n, size=n, replace=True)
        y_b, s_b = y[idx], s[idx]
        if np.unique(y_b).size < 2:
            continue
        boot_vals.append(metric_fn(y_b, s_b))

    if len(boot_vals) < 50:
        raise ValueError(
            f"Too few valid bootstrap resamples ({len(boot_vals)}/{n_boot}). "
            "Event rate may be too low for stable bootstrap CI."
        )

    ci_low, ci_high = np.percentile(boot_vals, [2.5, 97.5])
    return point, ci_low, ci_high, len(boot_vals)

#------------------------------------------------------------------------------
# 1. DISCRIMINATION FOR MORTALITY: AUROC + AUPRC (EXPLORATORY)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("1. DISCRIMINATION FOR IN-HOSPITAL MORTALITY (EXPLORATORY)")
print("="*70)

# AUROC with bootstrap CI
auroc_mort, auroc_mort_ci_lower, auroc_mort_ci_upper, n_valid_auc = bootstrap_metric_ci(
    df_mort,
    y_col='hospital_expire_flag',
    s_col='ban_adhf_total_score',
    metric_fn=roc_auc_score,
    n_boot=n_bootstrap
)

print(f"\nAUROC = {auroc_mort:.3f} (95% CI: {auroc_mort_ci_lower:.3f} to {auroc_mort_ci_upper:.3f})")
print(f"  Bootstrap resamples used: {n_valid_auc}/{n_bootstrap}")

print("\nInterpretation:")
if auroc_mort >= 0.7:
    print("  Moderate-good discrimination (AUROC ≥ 0.7)")
elif auroc_mort >= 0.6:
    print("  Poor-moderate discrimination (AUROC 0.6-0.7)")
else:
    print("  Poor discrimination (AUROC < 0.6)")

print("\nContext: BAN-ADHF was not designed for mortality prediction,")
print("so modest discrimination for this outcome is expected.")

# AUPRC with bootstrap CI (helpful when outcome is imbalanced)
auprc_mort, auprc_ci_lower, auprc_ci_upper, n_valid_pr = bootstrap_metric_ci(
    df_mort,
    y_col='hospital_expire_flag',
    s_col='ban_adhf_total_score',
    metric_fn=average_precision_score,
    n_boot=n_bootstrap
)

print(f"\nAUPRC = {auprc_mort:.3f} (95% CI: {auprc_ci_lower:.3f} to {auprc_ci_upper:.3f})")
print(f"  Bootstrap resamples used: {n_valid_pr}/{n_bootstrap}")

#------------------------------------------------------------------------------
# 2. EXPLORATORY CUTOFF (YOUDEN'S INDEX) + COMPARISON TO PRE-SPECIFIED THRESHOLD
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("2. EXPLORATORY CUTOFF FOR MORTALITY (YOUDEN'S INDEX)")
print("="*70)

fpr_m, tpr_m, thresholds_m = roc_curve(
    df_mort['hospital_expire_flag'],
    df_mort['ban_adhf_total_score']
)

youden_j_m = tpr_m - fpr_m
optimal_idx_m = int(np.argmax(youden_j_m))
optimal_cutoff_m = thresholds_m[optimal_idx_m]
optimal_sens_m = tpr_m[optimal_idx_m]
optimal_spec_m = 1 - fpr_m[optimal_idx_m]

print("\nExploratory (Youden-based) cutoff:")
print(f"  Cutoff: ≥{optimal_cutoff_m:.0f}")
print(f"  Sensitivity: {optimal_sens_m:.3f} ({optimal_sens_m*100:.1f}%)")
print(f"  Specificity: {optimal_spec_m:.3f} ({optimal_spec_m*100:.1f}%)")
print(f"  Youden's J: {youden_j_m[optimal_idx_m]:.3f}")

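# Comparison to prespecified high-risk threshold (≥13), computed directly from the
# rule "score ≥ 13": roc_curve may not return 13 as a threshold (the score value may
# be absent or dropped by drop_intermediate), so a nearest-threshold lookup can
# report the operating point of a neighbouring cutoff.
y_m = df_mort['hospital_expire_flag'].values.astype(bool)
pred_13_m = df_mort['ban_adhf_total_score'].values >= 13
sens_13_m = pred_13_m[y_m].mean()        # deaths classified high risk
spec_13_m = (~pred_13_m[~y_m]).mean()    # survivors not classified high risk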

print("\nComparison to prespecified high-risk threshold (≥13):")
print(f"  Sensitivity: {sens_13_m:.3f} ({sens_13_m*100:.1f}%)")
print(f"  Specificity: {spec_13_m:.3f} ({spec_13_m*100:.1f}%)")

#------------------------------------------------------------------------------
# 3. MORTALITY BY PREDEFINED RISK CATEGORY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("3. MORTALITY BY RISK CATEGORY")
print("="*70)

print(f"\n{'Risk Category':<15} {'N':<8} {'Deaths n (%)':<20} {'Survivors n (%)':<20}")
print("-"*65)

mort_by_risk = {}
for cat in ['Low', 'Moderate', 'High']:
    subset = df_mort[df_mort['risk_category'] == cat]
    n = len(subset)
    n_deaths_cat = int(subset['hospital_expire_flag'].sum())
    mort_rate = 100 * n_deaths_cat / n if n > 0 else 0.0
    mort_by_risk[cat] = {'n': n, 'n_deaths': n_deaths_cat, 'rate': mort_rate}
    deaths_str = f"{n_deaths_cat} ({mort_rate:.1f}%)"
    surv_str = f"{n - n_deaths_cat} ({100 - mort_rate:.1f}%)"
    print(f"{cat:<15} {n:<8} {deaths_str:<20} {surv_str:<20}")

# Chi-square test
contingency_m = pd.crosstab(df_mort['risk_category'], df_mort['hospital_expire_flag'])
chi2_m, p_chi_m, dof_m, expected_m = chi2_contingency(contingency_m)
print("\n" + (f"Chi-square = {chi2_m:.1f}, p < 0.001" if p_chi_m < 0.001 else f"Chi-square = {chi2_m:.1f}, p = {p_chi_m:.3f}"))

# Risk ratio + absolute risk increase
low_mort = mort_by_risk['Low']['rate']
high_mort = mort_by_risk['High']['rate']

rr_mort = np.nan
if low_mort > 0:
    rr_mort = high_mort / low_mort
    print(f"\nRisk ratio (High vs Low): {rr_mort:.2f}")
    print(f"  Mortality in high-risk patients is {rr_mort:.1f}x that of low-risk patients")

ari = high_mort - low_mort
print(f"Absolute risk increase: {ari:.1f} percentage points")

#------------------------------------------------------------------------------
# 4. CARDIOGENIC SHOCK SUBGROUP (PRE-SPECIFIED)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("4. MORTALITY DISCRIMINATION IN CARDIOGENIC SHOCK SUBGROUP")
print("="*70)

df_cs = df_mort[df_mort['cardiogenic_shock'] == 1].copy()
df_no_cs = df_mort[df_mort['cardiogenic_shock'] == 0].copy()

print(f"\nCardiogenic Shock (N = {len(df_cs):,}):")
print(f"  Mortality rate: {100*df_cs['hospital_expire_flag'].mean():.1f}%")

if df_cs['hospital_expire_flag'].sum() > 10 and df_cs['hospital_expire_flag'].nunique() == 2:
    auroc_cs, auroc_cs_ci_lower, auroc_cs_ci_upper, n_valid_cs = bootstrap_metric_ci(
        df_cs,
        y_col='hospital_expire_flag',
        s_col='ban_adhf_total_score',
        metric_fn=roc_auc_score,
        n_boot=n_bootstrap
    )
    print(f"  AUROC = {auroc_cs:.3f} (95% CI: {auroc_cs_ci_lower:.3f} to {auroc_cs_ci_upper:.3f})")
    print(f"  Bootstrap resamples used: {n_valid_cs}/{n_bootstrap}")
else:
    auroc_cs = None
    auroc_cs_ci_lower, auroc_cs_ci_upper = (None, None)
    print("  AUROC: Insufficient events for analysis")

print(f"\nNo Cardiogenic Shock (N = {len(df_no_cs):,}):")
print(f"  Mortality rate: {100*df_no_cs['hospital_expire_flag'].mean():.1f}%")

if df_no_cs['hospital_expire_flag'].sum() > 10 and df_no_cs['hospital_expire_flag'].nunique() == 2:
    auroc_no_cs, auroc_no_cs_ci_lower, auroc_no_cs_ci_upper, n_valid_no_cs = bootstrap_metric_ci(
        df_no_cs,
        y_col='hospital_expire_flag',
        s_col='ban_adhf_total_score',
        metric_fn=roc_auc_score,
        n_boot=n_bootstrap
    )
    print(f"  AUROC = {auroc_no_cs:.3f} (95% CI: {auroc_no_cs_ci_lower:.3f} to {auroc_no_cs_ci_upper:.3f})")
    print(f"  Bootstrap resamples used: {n_valid_no_cs}/{n_bootstrap}")
else:
    auroc_no_cs = None
    auroc_no_cs_ci_lower, auroc_no_cs_ci_upper = (None, None)
    print("  AUROC: Insufficient events for analysis")

if auroc_cs is not None and auroc_no_cs is not None:
    auroc_diff = auroc_cs - auroc_no_cs
    print(f"\nDifference (CS - No CS): {auroc_diff:+.3f}")
    print("→ Point-estimate AUROC is higher in cardiogenic shock patients" if auroc_diff > 0
          else "→ Point-estimate AUROC is not higher in cardiogenic shock patients")
    print("  (descriptive only; the difference was not formally tested)")

#------------------------------------------------------------------------------
# 5. COMPARISON: MORTALITY vs DIURETIC EFFICIENCY OUTCOMES
#    Uses results_24h_auc (Q20/Q25) and results_dr if they exist.
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("5. COMPARISON: MORTALITY vs DIURETIC EFFICIENCY OUTCOMES")
print("="*70)

has_24h_auc = 'results_24h_auc' in globals() and isinstance(results_24h_auc, dict)
has_dr = 'results_dr' in globals() and isinstance(results_dr, dict)

def _fmt_auc(value):
    """Format an AUROC that may be None (subgroup with too few events)."""
    return "NA" if value is None else f"{value:.3f}"

if has_24h_auc and has_dr and ('auroc' in results_dr):
    print(f"""
Outcome                         AUROC           Population      Purpose
--------------------------------------------------------------------------------
24h Efficiency (Q20)            {results_24h_auc['auroc_q20']:.3f}          N={results_24h_auc['n']:,}        Score's intended use
24h Efficiency (Q25)            {results_24h_auc['auroc_q25']:.3f}          N={results_24h_auc['n']:,}        Score's intended use
Binary DR (≤3L)                 {results_dr['auroc']:.3f}          N={results_dr['n']:,}        Secondary outcome
In-hospital mortality           {auroc_mort:.3f}          N={n_total:,}        EXPLORATORY
  - Cardiogenic shock           {_fmt_auc(auroc_cs):<5}          N={len(df_cs):,}        Subgroup
  - No cardiogenic shock        {_fmt_auc(auroc_no_cs):<5}          N={len(df_no_cs):,}        Subgroup
""")
else:
    print("Comparison table not shown because results_24h_auc and/or results_dr are missing.")
    if not has_24h_auc:
        print("  - results_24h_auc not found. Run the cell that creates results_24h_auc (24h AUROC summary).")
    if not has_dr or ('auroc' not in results_dr):
        print("  - results_dr not found or missing 'auroc'. Run the DR cell that creates results_dr['auroc'].")

#------------------------------------------------------------------------------
# STORE RESULTS
#------------------------------------------------------------------------------
results_mortality = {
    'n': n_total,
    'n_deaths': n_deaths,
    'mortality_rate': mortality_rate,
    'auroc': auroc_mort,
    'auroc_ci': (auroc_mort_ci_lower, auroc_mort_ci_upper),
    'auprc': auprc_mort,
    'auprc_ci': (auprc_ci_lower, auprc_ci_upper),
    'youden_cutoff': float(optimal_cutoff_m),
    'youden_sensitivity': float(optimal_sens_m),
    'youden_specificity': float(optimal_spec_m),
    'threshold_13_sensitivity': float(sens_13_m),
    'threshold_13_specificity': float(spec_13_m),
    'mort_by_risk': mort_by_risk,
    'chi2': float(chi2_m),
    'p_value': float(p_chi_m),
    'risk_ratio_high_vs_low': float(rr_mort) if not np.isnan(rr_mort) else None,
    'absolute_risk_increase_pp': float(ari),
    'auroc_cs': auroc_cs,
    'auroc_cs_ci': (auroc_cs_ci_lower, auroc_cs_ci_upper),
    'auroc_no_cs': auroc_no_cs,
    'auroc_no_cs_ci': (auroc_no_cs_ci_lower, auroc_no_cs_ci_upper)
}

#------------------------------------------------------------------------------
# SUMMARY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SUMMARY: IN-HOSPITAL MORTALITY (EXPLORATORY)")
print("="*70)

print(f"""
Population: Full cohort (N = {n_total:,})
Mortality Rate: {mortality_rate:.1f}%

IMPORTANT CONTEXT:
BAN-ADHF was designed for diuretic efficiency, NOT mortality.
These results represent a novel exploratory extension.

Discrimination (Exploratory):
  AUROC = {auroc_mort:.3f} (95% CI: {auroc_mort_ci_lower:.3f} to {auroc_mort_ci_upper:.3f})
  AUPRC = {auprc_mort:.3f} (95% CI: {auprc_ci_lower:.3f} to {auprc_ci_upper:.3f})

Mortality by Risk Category:
  Low risk (≤7):      {mort_by_risk['Low']['rate']:.1f}% (N={mort_by_risk['Low']['n']})
  Moderate (8-12):    {mort_by_risk['Moderate']['rate']:.1f}% (N={mort_by_risk['Moderate']['n']})
  High risk (≥13):    {mort_by_risk['High']['rate']:.1f}% (N={mort_by_risk['High']['n']})

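Risk Ratio (High vs Low): {'NA' if np.isnan(rr_mort) else f'{rr_mort:.2f}x'}
Absolute Risk Increase: {ari:.1f} percentage points

Subgroup Analysis:
  Cardiogenic shock:     AUROC = {'NA' if auroc_cs is None else f'{auroc_cs:.3f}'}
  No cardiogenic shock:  AUROC = {'NA' if auroc_no_cs is None else f'{auroc_no_cs:.3f}'}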
""")

print("✓ Mortality results stored in 'results_mortality' dictionary")
print("\n→ Next: Cell 12b - Discrimination Comparison Framework")

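Sensitivity and specificity at a fixed rule such as "score ≥ 13" can be computed directly from the scores, without looking up the nearest `roc_curve` threshold. That lookup can return a neighbouring cutoff when no patient scores exactly 13 or when `roc_curve` drops the threshold (its default `drop_intermediate=True` removes some suboptimal points). A minimal sketch with a hypothetical helper on synthetic data:

```python
import numpy as np


def sens_spec_at_cutoff(y_true, score, cutoff):
    """Sensitivity/specificity for the rule 'positive if score >= cutoff'."""
    y = np.asarray(y_true).astype(bool)
    pred = np.asarray(score, dtype=float) >= cutoff
    sensitivity = pred[y].mean()         # events classified high risk
    specificity = (~pred[~y]).mean()     # non-events not classified high risk
    return float(sensitivity), float(specificity)


# Synthetic scores in which no patient scores exactly 13
y = [0, 0, 1, 0, 1, 1, 0, 1]
score = [6, 9, 11, 12, 14, 15, 16, 18]
print(sens_spec_at_cutoff(y, score, 13))  # (0.75, 0.75)
```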

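The cardiogenic shock comparison reports only the difference between the two subgroup point AUROCs. Because the subgroups are disjoint, one way to attach uncertainty to that difference is to resample each subgroup independently and take percentiles of the AUROC difference. The sketch below is an illustration under that assumption, not the notebook's pre-specified method; `bootstrap_auroc_difference` is a hypothetical helper, and the data are synthetic.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def bootstrap_auroc_difference(y1, s1, y2, s2, n_boot=1000, seed=42):
    """Percentile-bootstrap 95% CI for AUROC(group 1) - AUROC(group 2).

    The groups are disjoint, so each is resampled independently. Resamples in
    which either group contains only one outcome class are skipped.
    Returns (point_difference, ci_low, ci_high, n_valid_resamples).
    """
    rng = np.random.default_rng(seed)  # local RNG; does not touch np.random.seed
    y1, s1, y2, s2 = (np.asarray(a) for a in (y1, s1, y2, s2))
    point = roc_auc_score(y1, s1) - roc_auc_score(y2, s2)

    diffs = []
    for _ in range(n_boot):
        i1 = rng.integers(0, len(y1), size=len(y1))
        i2 = rng.integers(0, len(y2), size=len(y2))
        if np.unique(y1[i1]).size < 2 or np.unique(y2[i2]).size < 2:
            continue
        diffs.append(roc_auc_score(y1[i1], s1[i1]) - roc_auc_score(y2[i2], s2[i2]))

    ci_low, ci_high = np.percentile(diffs, [2.5, 97.5])
    return float(point), float(ci_low), float(ci_high), len(diffs)


# Synthetic illustration: a strongly and a weakly informative score
data_rng = np.random.default_rng(0)
y_a = data_rng.integers(0, 2, size=200)
s_a = 2.0 * y_a + data_rng.normal(size=200)
y_b = data_rng.integers(0, 2, size=300)
s_b = 0.5 * y_b + data_rng.normal(size=300)

diff, lo, hi, n_ok = bootstrap_auroc_difference(y_a, s_a, y_b, s_b, n_boot=500)
print(f"ΔAUROC = {diff:+.3f} (95% CI {lo:+.3f} to {hi:+.3f}); {n_ok} valid resamples")
```

Applied to this cohort, `y1`/`s1` would be `hospital_expire_flag`/`ban_adhf_total_score` from `df_cs` and `y2`/`s2` the same columns from `df_no_cs`.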
In [None]:
#==========================================================================
# SECTION 6b: DISCRIMINATION COMPARISON FRAMEWORK
# Cell 12b: Literature Comparison and Prevalence Effect Analysis
#==========================================================================
# PURPOSE: Synthesize discrimination results with literature comparisons
# OUTPUT: Comparison tables and manuscript-ready interpretations
#==========================================================================

print("="*70)
print("DISCRIMINATION COMPARISON FRAMEWORK")
print("="*70)
print("\nThis cell synthesizes results from Cells 9-12 with literature context.")

#------------------------------------------------------------------------------
# SAFETY CHECKS
#------------------------------------------------------------------------------
missing = []
if 'results_24h_auc' not in globals():
    missing.append("results_24h_auc (Q20/Q25 AUROC summary)")
if 'results_dr' not in globals():
    missing.append("results_dr (binary DR results)")
if 'results_mortality' not in globals():
    missing.append("results_mortality (mortality results)")

# results_24h is used only for rho/r/c-index text. If it is missing or overwritten,
# we still run the AUROC tables using results_24h_auc.
has_results_24h_continuous = (
    'results_24h' in globals() and isinstance(results_24h, dict)
    and 'spearman_rho' in results_24h and 'pearson_r' in results_24h
    and 'spearman_ci' in results_24h and 'pearson_ci' in results_24h
)

if missing:
    raise ValueError("Missing required objects: " + ", ".join(missing))

# Convenience handles
r_auc = results_24h_auc
r_dr = results_dr
r_m = results_mortality

# Pull prevalence from results_dr (not hard-coded)
dr_prev = float(r_dr.get('prevalence', np.nan))

#------------------------------------------------------------------------------
# TABLE X: DISCRIMINATION BY OUTCOME DEFINITION
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("TABLE X: DISCRIMINATION BY OUTCOME DEFINITION")
print("="*70)

print(f"""
Outcome                      Definition              Prevalence    AUROC (95% CI)                    Reference
---------------------------------------------------------------------------------------------------------------
Lowest efficiency quintile   Bottom 20%              20.0%         {r_auc['auroc_q20']:.3f} ({r_auc['auroc_q20_ci'][0]:.3f}–{r_auc['auroc_q20_ci'][1]:.3f})        Segar: 0.84
Lowest efficiency quartile   Bottom 25%              25.0%         {r_auc['auroc_q25']:.3f} ({r_auc['auroc_q25_ci'][0]:.3f}–{r_auc['auroc_q25_ci'][1]:.3f})        Pandey: 0.70
Binary DR                    Urine ≤3,000 mL         {dr_prev:.1f}%        {r_dr['auroc']:.3f} ({r_dr['auroc_ci'][0]:.3f}–{r_dr['auroc_ci'][1]:.3f})        Mauch: 0.631
In-hospital mortality        Death during admission  {r_m['mortality_rate']:.1f}%        {r_m['auroc']:.3f} ({r_m['auroc_ci'][0]:.3f}–{r_m['auroc_ci'][1]:.3f})        Exploratory
""")

print("""
Key Observations:
- The percentile-based efficiency outcomes (Q20/Q25) show good discrimination.
- Binary DR is a coarser, dose-independent threshold and discards information captured by diuretic efficiency.
- Mortality discrimination is limited, consistent with the score having been developed for diuretic efficiency rather than mortality.
""")

#------------------------------------------------------------------------------
# LITERATURE COMPARISON TABLE: DIURETIC EFFICIENCY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("LITERATURE COMPARISON: DIURETIC EFFICIENCY PREDICTION")
print("="*70)

if has_results_24h_continuous:
    print(f"""
Study           Population       N       Correlation                AUROC (Quintile)   AUROC (Quartile)
-------------------------------------------------------------------------------------------------------
Segar 2024      Trial derivation 707     —                          0.84               —
Mauch 2025      Floor patients   317     r = −0.40                  —                  —
Pandey 2025     CLOROTIC trial   220     —                          —                  0.70
-------------------------------------------------------------------------------------------------------
This study      ICU patients     {r_auc['n']:,}   ρ = {results_24h['spearman_rho']:.3f} (95% CI {results_24h['spearman_ci'][0]:.3f}–{results_24h['spearman_ci'][1]:.3f})     {r_auc['auroc_q20']:.3f}              {r_auc['auroc_q25']:.3f}
                                         r = {results_24h['pearson_r']:.3f} (95% CI {results_24h['pearson_ci'][0]:.3f}–{results_24h['pearson_ci'][1]:.3f})
""")
else:
    print(f"""
Note: results_24h (continuous correlation summary) not available or has been overwritten.
Showing AUROC comparisons only from results_24h_auc.

This study (ICU patients, N={r_auc['n']:,}):
  AUROC Q20: {r_auc['auroc_q20']:.3f}
  AUROC Q25: {r_auc['auroc_q25']:.3f}
""")

#------------------------------------------------------------------------------
# LITERATURE COMPARISON TABLE: BINARY DIURETIC RESISTANCE
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("LITERATURE COMPARISON: BINARY DIURETIC RESISTANCE")
print("="*70)

print(f"""
Study           Population       DR Prevalence    Binary DR AUROC    Note
---------------------------------------------------------------------------
Mauch 2025      Floor patients   ~25%             0.631              Reference
This study      ICU patients     {dr_prev:.1f}%            {r_dr['auroc']:.3f}              Higher prevalence and more severe case-mix
---------------------------------------------------------------------------
Difference (AUROC):                               {r_dr['auroc'] - 0.631:+.3f}
""")

print(f"""
Interpretation:
Binary DR is a simplified threshold. When a large share of ICU patients meets it
(DR prevalence {dr_prev:.1f}% in this cohort), it becomes a less informative endpoint
than continuous diuretic efficiency.
""")

#------------------------------------------------------------------------------
# COMPLETE COMPARISON FRAMEWORK
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("COMPLETE VALIDATION COMPARISON FRAMEWORK")
print("="*70)

# Deltas for AUROC comparisons
delta_quintile = ((r_auc['auroc_q20'] - 0.84) / 0.84) * 100
delta_quartile = ((r_auc['auroc_q25'] - 0.70) / 0.70) * 100
delta_dr = ((r_dr['auroc'] - 0.631) / 0.631) * 100

if has_results_24h_continuous:
    delta_pearson = ((abs(results_24h['pearson_r']) - 0.40) / 0.40) * 100
    pearson_line = f"{results_24h['pearson_r']:.3f}        Mauch: −0.40     {delta_pearson:+.1f}%       Comparable magnitude"
    spearman_line = f"{results_24h['spearman_rho']:.3f}        —                —           Strong inverse association"
else:
    pearson_line = "NA           Mauch: −0.40     —           Correlation summary not available in this session"
    spearman_line = "NA           —                —           Correlation summary not available in this session"

# Subgroup AUROCs are None when a stratum had too few events
auc_cs_str = "NA" if r_m['auroc_cs'] is None else f"{r_m['auroc_cs']:.3f}"
auc_no_cs_str = "NA" if r_m['auroc_no_cs'] is None else f"{r_m['auroc_no_cs']:.3f}"

print(f"""
Metric                  This Study     Comparator       Δ           Interpretation
------------------------------------------------------------------------------------
Spearman ρ              {spearman_line}
Pearson r               {pearson_line}
AUROC quintile (Q20)    {r_auc['auroc_q20']:.3f}        Segar: 0.84      {delta_quintile:+.1f}%       Comparable
AUROC quartile (Q25)    {r_auc['auroc_q25']:.3f}        Pandey: 0.70     {delta_quartile:+.1f}%       Better
Binary DR AUROC         {r_dr['auroc']:.3f}        Mauch: 0.631     {delta_dr:+.1f}%       Lower with binary threshold in ICU case-mix
In-hospital mortality   {r_m['auroc']:.3f}        —                —           Exploratory (not designed)
  - Cardiogenic shock   {auc_cs_str:<5}        —                —           Improved discrimination
  - No CS               {auc_no_cs_str:<5}        —                —           Limited discrimination
""")

#------------------------------------------------------------------------------
# KEY MANUSCRIPT SENTENCES (shorter, cleaner)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("KEY MANUSCRIPT SENTENCES")
print("="*70)

if has_results_24h_continuous:
    eff_sentence = (
        f"The BAN-ADHF score was inversely associated with 24-hour diuretic efficiency "
        f"(Spearman ρ = {results_24h['spearman_rho']:.3f}, 95% CI {results_24h['spearman_ci'][0]:.3f}–{results_24h['spearman_ci'][1]:.3f}; p<0.001). "
        f"For comparison with prior work, Pearson r was {results_24h['pearson_r']:.3f}."
    )
else:
    eff_sentence = (
        "The BAN-ADHF score demonstrated good discrimination for low 24-hour efficiency using percentile-based definitions."
    )

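# Describe the shock-subgroup contrast from the stored AUROCs (None = too few events)
auc_cs = r_m['auroc_cs']
auc_no_cs = r_m['auroc_no_cs']
if auc_cs is not None and auc_no_cs is not None:
    cs_word = "higher" if auc_cs > auc_no_cs else "not higher"
    cs_sentence = (f"Discrimination was {cs_word} in cardiogenic shock (AUROC {auc_cs:.3f}) "
                   f"than in non-shock patients (AUROC {auc_no_cs:.3f}).")
else:
    cs_sentence = "Shock-subgroup AUROCs could not be estimated in both strata because of too few events."

print(f"""
RESULTS (diuretic efficiency):
{eff_sentence}
Discrimination for low efficiency was good (AUROC Q20 {r_auc['auroc_q20']:.3f}; AUROC Q25 {r_auc['auroc_q25']:.3f}).

RESULTS (binary diuretic resistance):
Binary diuretic resistance (24-hour urine output ≤3,000 mL) showed modest discrimination (AUROC {r_dr['auroc']:.3f}).
This likely reflects information loss from dichotomizing a continuous phenomenon and the more severe ICU case-mix.

RESULTS (mortality, exploratory):
In exploratory analysis, BAN-ADHF showed limited discrimination for in-hospital mortality (AUROC {r_m['auroc']:.3f}).
{cs_sentence}
""")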

#------------------------------------------------------------------------------
# STORE COMPARISON FRAMEWORK
#------------------------------------------------------------------------------
comparison_framework = {
    'efficiency_auroc': {
        'q20': {'value': r_auc['auroc_q20'], 'ci': r_auc['auroc_q20_ci'], 'reference': 0.84, 'delta_pct': delta_quintile},
        'q25': {'value': r_auc['auroc_q25'], 'ci': r_auc['auroc_q25_ci'], 'reference': 0.70, 'delta_pct': delta_quartile},
    },
    'binary_dr': {
        'auroc': {'value': r_dr['auroc'], 'ci': r_dr['auroc_ci'], 'reference': 0.631, 'delta_pct': delta_dr},
        'prevalence_this': dr_prev
    },
    'mortality': {
        'auroc_overall': {'value': r_m['auroc'], 'ci': r_m['auroc_ci']},
        'auroc_cs': r_m['auroc_cs'],
        'auroc_no_cs': r_m['auroc_no_cs']
    }
}

print("\n" + "="*70)
print("✓ Comparison framework stored in 'comparison_framework' dictionary")
print("="*70)
print("\n→ Next: Cell 13 - Subgroup Analyses")

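The interaction tests in Cell 13 compare subgroup Spearman correlations with `fisher_z_test`, which uses the Pearson standard error 1/(n − 3) on Fisher's z = arctanh(r). For Spearman's ρ a commonly cited approximation (Fieller, Hartley and Pearson, 1957) inflates that variance to about 1.06/(n − 3), which widens the test slightly. The sketch below, a hypothetical `compare_correlations` helper with illustrative numbers rather than study results, exposes that choice as a flag:

```python
import numpy as np
from scipy.stats import norm


def compare_correlations(r1, n1, r2, n2, spearman=True):
    """Two-sided z-test for the difference between two independent correlations.

    Uses Fisher's z = arctanh(r). With spearman=True the variance of z is
    approximated as 1.06 / (n - 3); with spearman=False it is 1 / (n - 3).
    Returns (z_statistic, p_value).
    """
    factor = 1.06 if spearman else 1.0
    z1, z2 = np.arctanh(r1), np.arctanh(r2)
    se = np.sqrt(factor / (n1 - 3) + factor / (n2 - 3))
    z_stat = (z1 - z2) / se
    p_value = 2 * norm.sf(abs(z_stat))   # sf = 1 - cdf, more precise in the tail
    return float(z_stat), float(p_value)


# Illustrative numbers only (not study results)
print(compare_correlations(-0.55, 400, -0.45, 600, spearman=False))
print(compare_correlations(-0.55, 400, -0.45, 600, spearman=True))   # larger p-value
```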

In [None]:
#==========================================================================
# SECTION 8: SUBGROUP ANALYSES
# Cell 13: Pre-specified Subgroup Analyses
#==========================================================================
# PURPOSE: Evaluate BAN-ADHF performance across pre-specified subgroups
# OUTCOME: 24-hour diuretic efficiency (score's intended purpose)
# METRICS: Spearman correlation (with bootstrap CI), AUROC (Q20) per subgroup
# OUTPUT: Tables, interaction p-values, forest plot data
#==========================================================================

from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score
import numpy as np
import pandas as pd

print("="*70)
print("SUBGROUP ANALYSES: 24-HOUR DIURETIC EFFICIENCY")
print("="*70)
print("\nPrimary outcome: 24-hour diuretic efficiency (BAN-ADHF's intended purpose)")

#------------------------------------------------------------------------------
# PREPARE ANALYSIS COHORT
#------------------------------------------------------------------------------

df_sub = df[(df['icu_stay_ge_24h'] == 1) &
            (df['diuretic_efficiency_24h'].notna()) &
            (df['diuretic_efficiency_24h'] > 0)].copy()

print(f"\nAnalysis cohort: N = {len(df_sub):,}")

# Define lowest quintile for AUROC (Q20)
quintile_threshold = df_sub['diuretic_efficiency_24h'].quantile(0.20)
df_sub['lowest_quintile'] = (df_sub['diuretic_efficiency_24h'] <= quintile_threshold).astype(int)

print(f"Lowest quintile threshold: ≤{quintile_threshold:.1f} mL/mg")

#------------------------------------------------------------------------------
# BOOTSTRAP SETUP
#------------------------------------------------------------------------------
np.random.seed(42)
n_bootstrap = 1000

#------------------------------------------------------------------------------
# DEFINE SUBGROUPS
# NOTE: Sex uses the MIMIC-IV 'gender' column, coded 'M' / 'F'
#------------------------------------------------------------------------------
subgroups = {
    'Age ≥65 years': {
        'var': 'age_65_or_older',
        'groups': {1: '≥65 years', 0: '<65 years'}
    },
    'Sex': {
        'var': 'gender',  # MIMIC-IV coding: 'M' / 'F'
        'groups': {'M': 'Male', 'F': 'Female'}
    },
    'Diabetes': {
        'var': 'hx_diabetes',
        'groups': {1: 'Diabetes', 0: 'No diabetes'}
    },
    'Chronic kidney disease': {
        'var': 'chronic_advanced_ckd',
        'groups': {1: 'Advanced CKD', 0: 'No advanced CKD'}
    },
    'Atrial fibrillation': {
        'var': 'hx_atrial_fibrillation',
        'groups': {1: 'AFib', 0: 'No AFib'}
    },
    'Home diuretics': {
        'var': 'on_home_diuretics',
        'groups': {1: 'On home diuretics', 0: 'No home diuretics'}
    },
    'Cardiogenic shock': {
        'var': 'cardiogenic_shock',
        'groups': {1: 'Cardiogenic shock', 0: 'No cardiogenic shock'}
    }
}

#------------------------------------------------------------------------------
# HELPERS
#------------------------------------------------------------------------------

def bootstrap_spearman_ci(data: pd.DataFrame, n_boot: int = 1000):
    """Spearman rho + bootstrap CI"""
    rho, p = spearmanr(data['ban_adhf_total_score'], data['diuretic_efficiency_24h'])

    boot = []
    for _ in range(n_boot):
        idx = np.random.choice(len(data), size=len(data), replace=True)
        sample = data.iloc[idx]
        r, _ = spearmanr(sample['ban_adhf_total_score'], sample['diuretic_efficiency_24h'])
        boot.append(r)

    # Constant resamples make spearmanr return nan; ignore those when taking percentiles
    ci = (np.nanpercentile(boot, 2.5), np.nanpercentile(boot, 97.5))
    return rho, ci, p


def bootstrap_auroc_ci(data: pd.DataFrame, n_boot: int = 1000):
    """AUROC (Q20) + bootstrap CI"""
    # guard for too-few events / no variation
    events = int(data['lowest_quintile'].sum())
    if events <= 5 or events >= len(data) - 5:
        return None, (None, None)

    auc = roc_auc_score(data['lowest_quintile'], data['ban_adhf_total_score'])

    boot = []
    for _ in range(n_boot):
        idx = np.random.choice(len(data), size=len(data), replace=True)
        sample = data.iloc[idx]
        if sample['lowest_quintile'].sum() > 0 and sample['lowest_quintile'].sum() < len(sample):
            boot.append(roc_auc_score(sample['lowest_quintile'], sample['ban_adhf_total_score']))

    ci = (np.percentile(boot, 2.5), np.percentile(boot, 97.5))
    return auc, ci


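def fisher_z_test(r1, n1, r2, n2):
    """Compare two independent correlations with Fisher's z-transform (two-tailed).

    Uses the Pearson standard error sqrt(1/(n1-3) + 1/(n2-3)). Applied to
    Spearman's rho this is an approximation: the variance of z for rho is
    somewhat larger (about 1.06/(n-3)), so p-values may be slightly too small.
    """
    from scipy.stats import norm

    z1 = np.arctanh(r1)  # identical to 0.5 * log((1 + r) / (1 - r))
    z2 = np.arctanh(r2)
    se = np.sqrt(1/(n1-3) + 1/(n2-3))
    z_stat = (z1 - z2) / se

    p_value = 2 * norm.sf(abs(z_stat))  # sf = 1 - cdf
    return z_stat, p_value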


def calculate_subgroup_metrics(data: pd.DataFrame, n_boot: int = 1000):
    rho, rho_ci, p = bootstrap_spearman_ci(data, n_boot=n_boot)
    auc, auc_ci = bootstrap_auroc_ci(data, n_boot=n_boot)
    return {
        'n': len(data),
        'n_events': int(data['lowest_quintile'].sum()),
        'spearman_rho': float(rho),
        'spearman_ci': (float(rho_ci[0]), float(rho_ci[1])),
        'spearman_p': float(p),
        'auroc': None if auc is None else float(auc),
        'auroc_ci': (None, None) if auc is None else (float(auc_ci[0]), float(auc_ci[1]))
    }

#------------------------------------------------------------------------------
# OVERALL (for forest plot row)
#------------------------------------------------------------------------------
overall_rho, overall_rho_ci, overall_p = bootstrap_spearman_ci(df_sub, n_boot=n_bootstrap)

#------------------------------------------------------------------------------
# ANALYZE EACH SUBGROUP
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SUBGROUP RESULTS: SPEARMAN CORRELATION")
print("="*70)

subgroup_results = {}

print(f"\n{'Subgroup':<30} {'N':<8} {'Spearman ρ':<12} {'95% CI':<22} {'p-value':<10}")
print("-"*85)

for subgroup_name, subgroup_info in subgroups.items():
    var = subgroup_info['var']
    groups = subgroup_info['groups']

    subgroup_results[subgroup_name] = {}

    for value, label in groups.items():
        subset = df_sub[df_sub[var] == value]

        if len(subset) >= 30:
            metrics = calculate_subgroup_metrics(subset, n_boot=n_bootstrap)
            subgroup_results[subgroup_name][label] = metrics

            ci_str = f"({metrics['spearman_ci'][0]:.3f}, {metrics['spearman_ci'][1]:.3f})"
            p_str = "<0.001" if metrics['spearman_p'] < 0.001 else f"{metrics['spearman_p']:.3f}"

            print(f"{label:<30} {metrics['n']:<8} {metrics['spearman_rho']:<12.3f} {ci_str:<22} {p_str:<10}")
        else:
            print(f"{label:<30} {len(subset):<8} {'Insufficient N':<12}")

#------------------------------------------------------------------------------
# INTERACTION TESTS
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("INTERACTION TESTS: DIFFERENCE IN CORRELATIONS")
print("="*70)
print("\nFisher's z-transformation to compare correlations between subgroups")

print(f"\n{'Comparison':<45} {'ρ₁':<8} {'ρ₂':<8} {'Δρ':<10} {'p-interaction':<12}")
print("-"*90)

interaction_results = {}

for subgroup_name, subgroup_info in subgroups.items():
    labels = list(subgroup_info['groups'].values())

    if len(labels) == 2 and all(label in subgroup_results[subgroup_name] for label in labels):
        r1 = subgroup_results[subgroup_name][labels[0]]['spearman_rho']
        n1 = subgroup_results[subgroup_name][labels[0]]['n']
        r2 = subgroup_results[subgroup_name][labels[1]]['spearman_rho']
        n2 = subgroup_results[subgroup_name][labels[1]]['n']

        z_stat, p_int = fisher_z_test(r1, n1, r2, n2)
        delta_rho = r1 - r2

        interaction_results[subgroup_name] = {
            'rho_1': r1, 'rho_2': r2, 'delta': delta_rho,
            'z_stat': z_stat, 'p_interaction': p_int
        }

        p_str = "<0.001" if p_int < 0.001 else f"{p_int:.3f}"
        sig = "*" if p_int < 0.05 else ""
        comparison = f"{labels[0]} vs {labels[1]}"
        print(f"{comparison:<45} {r1:<8.3f} {r2:<8.3f} {delta_rho:<+10.3f} {p_str:<12} {sig}")

#------------------------------------------------------------------------------
# AUROC BY SUBGROUP
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SUBGROUP RESULTS: AUROC FOR LOWEST QUINTILE")
print("="*70)

print(f"\n{'Subgroup':<30} {'N':<8} {'Events':<8} {'AUROC':<10} {'95% CI':<22}")
print("-"*90)

for subgroup_name, groups_data in subgroup_results.items():
    for label, metrics in groups_data.items():
        if metrics['auroc'] is not None:
            ci_str = f"({metrics['auroc_ci'][0]:.3f}, {metrics['auroc_ci'][1]:.3f})"
            print(f"{label:<30} {metrics['n']:<8} {metrics['n_events']:<8} {metrics['auroc']:<10.3f} {ci_str:<22}")
        else:
            print(f"{label:<30} {metrics['n']:<8} {metrics['n_events']:<8} {'N/A':<10} {'':<22}")

#------------------------------------------------------------------------------
# HF PHENOTYPE SUBGROUP (3 categories)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("HF PHENOTYPE SUBGROUP ANALYSIS")
print("="*70)

hf_results = {}
df_hf = df_sub[df_sub['hf_phenotype'].notna()].copy()

print(f"\nPatients with HF phenotype data: N = {len(df_hf):,}")
print(f"\n{'HF Phenotype':<20} {'N':<8} {'Spearman ρ':<12} {'95% CI':<22} {'AUROC':<10}")
print("-"*80)

for phenotype in ['HFrEF', 'HFmrEF', 'HFpEF']:
    subset = df_hf[df_hf['hf_phenotype'] == phenotype]

    if len(subset) >= 30:
        metrics = calculate_subgroup_metrics(subset, n_boot=n_bootstrap)
        hf_results[phenotype] = metrics

        ci_str = f"({metrics['spearman_ci'][0]:.3f}, {metrics['spearman_ci'][1]:.3f})"
        auroc_str = f"{metrics['auroc']:.3f}" if metrics['auroc'] is not None else "N/A"
        print(f"{phenotype:<20} {metrics['n']:<8} {metrics['spearman_rho']:<12.3f} {ci_str:<22} {auroc_str:<10}")

#------------------------------------------------------------------------------
# FOREST PLOT DATA (FOR VISUALIZATION)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("FOREST PLOT DATA (FOR VISUALIZATION)")
print("="*70)

forest_data = []

# Overall (computed directly so no dependency on results_24h dict shape)
forest_data.append({
    'subgroup': 'Overall',
    'n': len(df_sub),
    'rho': float(overall_rho),
    'rho_ci_low': float(overall_rho_ci[0]),
    'rho_ci_high': float(overall_rho_ci[1])
})

# Binary subgroups
for subgroup_name, groups_data in subgroup_results.items():
    for label, metrics in groups_data.items():
        forest_data.append({
            'subgroup': label,
            'n': metrics['n'],
            'rho': metrics['spearman_rho'],
            'rho_ci_low': metrics['spearman_ci'][0],
            'rho_ci_high': metrics['spearman_ci'][1]
        })

# HF phenotypes
for phenotype, metrics in hf_results.items():
    forest_data.append({
        'subgroup': phenotype,
        'n': metrics['n'],
        'rho': metrics['spearman_rho'],
        'rho_ci_low': metrics['spearman_ci'][0],
        'rho_ci_high': metrics['spearman_ci'][1]
    })

forest_df = pd.DataFrame(forest_data)
print("\n")
print(forest_df.to_string(index=False))

#------------------------------------------------------------------------------
# KEY FINDINGS
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("KEY SUBGROUP FINDINGS")
print("="*70)

sig_interactions = {k: v for k, v in interaction_results.items() if v['p_interaction'] < 0.05}

if sig_interactions:
    print("\nSignificant interactions (p < 0.05):")
    for name, data in sig_interactions.items():
        p_str = "<0.001" if data['p_interaction'] < 0.001 else f"{data['p_interaction']:.3f}"
        print(f"  • {name}: Δρ = {data['delta']:+.3f}, p = {p_str}")
else:
    print("\nNo significant interactions detected (p < 0.05)")
    print("No evidence that BAN-ADHF performance differs across subgroups")
    print("(non-significant interaction tests do not establish equivalence)")

# strongest/weakest by |rho|
all_rhos = []
for subgroup_name, groups_data in subgroup_results.items():
    for label, metrics in groups_data.items():
        all_rhos.append((label, metrics['spearman_rho'], metrics['n']))

all_rhos.sort(key=lambda x: abs(x[1]), reverse=True)

# Guard: every subgroup may have been skipped for N < 30
if all_rhos:
    print(f"\nLargest |association|:  {all_rhos[0][0]} (ρ = {all_rhos[0][1]:.3f}, N={all_rhos[0][2]})")
    print(f"Smallest |association|: {all_rhos[-1][0]} (ρ = {all_rhos[-1][1]:.3f}, N={all_rhos[-1][2]})")

#------------------------------------------------------------------------------
# STORE RESULTS
#------------------------------------------------------------------------------
results_subgroups = {
    'subgroup_results': subgroup_results,
    'interaction_results': interaction_results,
    'hf_results': hf_results,
    'forest_data': forest_df
}

#------------------------------------------------------------------------------
# SUMMARY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SUMMARY: SUBGROUP ANALYSES")
print("="*70)

print(f"""
Analysis: 24-hour diuretic efficiency (BAN-ADHF's intended purpose)
Cohort: N = {len(df_sub):,}

Key Findings:
  • Overall Spearman ρ = {overall_rho:.3f} (95% CI: {overall_rho_ci[0]:.3f} to {overall_rho_ci[1]:.3f})
  • {len(sig_interactions)} significant interaction(s) detected (p < 0.05, unadjusted for multiple comparisons)
""")

print("✓ Subgroup results stored in 'results_subgroups' dictionary")
print("\n→ Next: Sensitivity Analyses")
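
The interaction tests above compare two independent correlations on Fisher's z scale. The standard-library sketch below reproduces that calculation with a worked example; `compare_correlations` is a hypothetical helper, not part of the analysis pipeline. Its `variance_factor` argument is an added assumption: the default of 1.0 matches the standard error used in `fisher_z_test`, sqrt(1/(n1-3) + 1/(n2-3)), while 1.06 applies the Fieller-Hartley-Pearson variance inflation that is sometimes recommended when the correlations are Spearman's ρ rather than Pearson's r.

```python
import math

def compare_correlations(r1, n1, r2, n2, variance_factor=1.0):
    """Two-sided test of H0: rho1 == rho2 for two independent groups (sketch).

    variance_factor=1.0 gives the usual Fisher SE; 1.06 is the
    Fieller-Hartley-Pearson adjustment sometimes used for Spearman's rho.
    """
    z1, z2 = math.atanh(r1), math.atanh(r2)  # Fisher z-transform of each correlation
    se = math.sqrt(variance_factor * (1 / (n1 - 3) + 1 / (n2 - 3)))
    z_stat = (z1 - z2) / se
    # Two-sided p-value: 2 * (1 - Phi(|z|)) == erfc(|z| / sqrt(2))
    p_value = math.erfc(abs(z_stat) / math.sqrt(2))
    return z_stat, p_value

# Hypothetical inputs for illustration only (not study results)
z, p = compare_correlations(-0.55, 400, -0.45, 400)
print(f"z = {z:.3f}, p = {p:.3f}")
```

Using `math.erfc` avoids a SciPy dependency and is numerically equivalent to `2 * (1 - norm.cdf(abs(z)))`.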

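Because BAN-ADHF is an integer score, many patients share the same value, so ties matter when interpreting AUROC. The value returned by `roc_auc_score` equals the probability that a randomly chosen event (for example, a lowest-quintile patient) has a higher score than a randomly chosen non-event, with tied pairs counted as one half. The pairwise sketch below makes that interpretation concrete on hypothetical toy data; `auroc_concordance` is an illustrative helper, quadratic in sample size, and is not used elsewhere in the notebook.

```python
def auroc_concordance(y, score):
    """AUROC as P(score_event > score_nonevent), counting ties as 1/2 (sketch)."""
    events = [s for yi, s in zip(y, score) if yi == 1]
    nonevents = [s for yi, s in zip(y, score) if yi == 0]
    if not events or not nonevents:
        raise ValueError("need at least one event and one non-event")
    total = 0.0
    for s_evt in events:
        for s_non in nonevents:
            if s_evt > s_non:
                total += 1.0      # concordant pair
            elif s_evt == s_non:
                total += 0.5      # tied pair: half credit
    return total / (len(events) * len(nonevents))

# Toy integer scores with ties (hypothetical data)
y_toy = [0, 0, 0, 1, 1, 1]
score_toy = [6, 9, 12, 9, 12, 15]
print(auroc_concordance(y_toy, score_toy))  # 7 of 9 pair-credits -> 0.777...
```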

In [None]:
#==========================================================================
# SECTION 9: SENSITIVITY ANALYSES
# Cell 14: Robustness Testing
#==========================================================================
# PURPOSE: Test robustness of primary findings under various conditions
# NOTES:
#   - Uses FIXED Q20 threshold from primary cohort for all analyses
#   - Ensures consistency with eTable 3 (Subgroup Analysis)
#   - Computes the primary reference metrics directly from df_base
#==========================================================================

from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score
import numpy as np
import pandas as pd

print("="*70)
print("SENSITIVITY ANALYSES")
print("="*70)
print("\nTesting robustness of primary findings under various conditions")

np.random.seed(42)
n_bootstrap = 1000

#------------------------------------------------------------------------------
# BASE COHORT
#------------------------------------------------------------------------------
df_base = df[(df['icu_stay_ge_24h'] == 1) &
             (df['diuretic_efficiency_24h'].notna()) &
             (df['diuretic_efficiency_24h'] > 0)].copy()

print(f"\nBase cohort (24h efficiency): N = {len(df_base):,}")

#------------------------------------------------------------------------------
# FIXED THRESHOLDS FROM PRIMARY COHORT
#------------------------------------------------------------------------------
FIXED_Q20_THRESHOLD = df_base['diuretic_efficiency_24h'].quantile(0.20)
FIXED_Q25_THRESHOLD = df_base['diuretic_efficiency_24h'].quantile(0.25)

print(f"Q20 threshold: ≤{FIXED_Q20_THRESHOLD:.1f} mL/mg")
print(f"Q25 threshold: ≤{FIXED_Q25_THRESHOLD:.1f} mL/mg")

#------------------------------------------------------------------------------
# BOOTSTRAP HELPERS
#------------------------------------------------------------------------------
def bootstrap_spearman_ci(data, n_boot=1000):
    rho, p = spearmanr(data['ban_adhf_total_score'], data['diuretic_efficiency_24h'])
    boot = []
    for _ in range(n_boot):
        idx = np.random.choice(len(data), size=len(data), replace=True)
        sample = data.iloc[idx]
        r, _ = spearmanr(sample['ban_adhf_total_score'], sample['diuretic_efficiency_24h'])
        boot.append(r)
    ci = (np.percentile(boot, 2.5), np.percentile(boot, 97.5))
    return float(rho), (float(ci[0]), float(ci[1])), float(p)

def bootstrap_auroc_ci(y, score, n_boot=1000):
    y = pd.Series(y).reset_index(drop=True)
    score = pd.Series(score).reset_index(drop=True)

    if y.sum() == 0 or y.sum() == len(y):
        return None, (None, None)

    auc = roc_auc_score(y, score)
    boot = []
    for _ in range(n_boot):
        idx = np.random.choice(len(y), size=len(y), replace=True)
        yb = y.iloc[idx]
        sb = score.iloc[idx]
        if yb.sum() > 0 and yb.sum() < len(yb):
            boot.append(roc_auc_score(yb, sb))

    ci = (np.percentile(boot, 2.5), np.percentile(boot, 97.5))
    return float(auc), (float(ci[0]), float(ci[1]))

def compute_primary_metrics(data, q20_thr, q25_thr):
    rho, rho_ci, p = bootstrap_spearman_ci(data, n_boot=n_bootstrap)

    # Q20 using fixed threshold
    y20 = (data['diuretic_efficiency_24h'] <= q20_thr).astype(int)
    auc20, auc20_ci = bootstrap_auroc_ci(y20, data['ban_adhf_total_score'], n_boot=n_bootstrap)

    # Q25 using fixed threshold
    y25 = (data['diuretic_efficiency_24h'] <= q25_thr).astype(int)
    auc25, auc25_ci = bootstrap_auroc_ci(y25, data['ban_adhf_total_score'], n_boot=n_bootstrap)

    return {
        'n': int(len(data)),
        'spearman_rho': rho,
        'spearman_ci': rho_ci,
        'spearman_p': p,
        'q20_threshold': float(q20_thr),
        'q25_threshold': float(q25_thr),
        'auroc_q20': auc20,
        'auroc_q20_ci': auc20_ci,
        'auroc_q25': auc25,
        'auroc_q25_ci': auc25_ci,
        'n_q20': int(y20.sum()),
        'n_q25': int(y25.sum()),
    }

#------------------------------------------------------------------------------
# REFERENCE VALUES (Primary Analysis)
#------------------------------------------------------------------------------
primary_24h = compute_primary_metrics(df_base, FIXED_Q20_THRESHOLD, FIXED_Q25_THRESHOLD)

print("\n" + "-"*70)
print("REFERENCE: PRIMARY ANALYSIS RESULTS")
print("-"*70)

print(f"""
Primary cohort (24h efficiency): N = {primary_24h['n']:,}
  Spearman ρ:      {primary_24h['spearman_rho']:.3f} ({primary_24h['spearman_ci'][0]:.3f} to {primary_24h['spearman_ci'][1]:.3f})
  AUROC (Q20):     {primary_24h['auroc_q20']:.3f} ({primary_24h['auroc_q20_ci'][0]:.3f} to {primary_24h['auroc_q20_ci'][1]:.3f})
  AUROC (Q25):     {primary_24h['auroc_q25']:.3f} ({primary_24h['auroc_q25_ci'][0]:.3f} to {primary_24h['auroc_q25_ci'][1]:.3f})
""")

# Optional: keep the Step 12 results dict aligned, if it was created earlier
if 'results_24h_auc' in globals():
    results_24h_auc.update({
        'n': primary_24h['n'],
        'auroc_q20': primary_24h['auroc_q20'],
        'auroc_q25': primary_24h['auroc_q25'],
        'auroc_q20_ci': primary_24h['auroc_q20_ci'],
        'auroc_q25_ci': primary_24h['auroc_q25_ci'],
    })

#------------------------------------------------------------------------------
# SENSITIVITY RUN HELPER (uses fixed threshold)
#------------------------------------------------------------------------------
def sensitivity_run(data, name, q20_thr):
    out = {'name': name, 'n': int(len(data))}
    if len(data) < 50:
        out.update({'spearman': None, 'spearman_ci': (None, None), 'auroc': None, 'auroc_ci': (None, None),
                    'n_events': None, 'note': 'Insufficient N'})
        return out

    rho, rho_ci, p = bootstrap_spearman_ci(data, n_boot=n_bootstrap)
    out['spearman'] = rho
    out['spearman_ci'] = rho_ci
    out['spearman_p'] = p

    # Use FIXED threshold from primary cohort
    y = (data['diuretic_efficiency_24h'] <= q20_thr).astype(int)
    out['threshold'] = float(q20_thr)
    out['n_events'] = int(y.sum())

    auc, auc_ci = bootstrap_auroc_ci(y, data['ban_adhf_total_score'], n_boot=n_bootstrap)
    out['auroc'] = auc
    out['auroc_ci'] = auc_ci
    out['note'] = ''
    return out

#------------------------------------------------------------------------------
# 1) EXCLUDE ADVANCED CKD
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("1. EXCLUDE ADVANCED CKD (eGFR <30 mL/min/1.73m²)")
print("="*70)

df_no_ckd = df_base[df_base['chronic_advanced_ckd'] == 0].copy()
sa1 = sensitivity_run(df_no_ckd, "Exclude advanced CKD", FIXED_Q20_THRESHOLD)

print(f"\nExcluded: {len(df_base) - len(df_no_ckd)} patients with advanced CKD")
print(f"Remaining: N = {sa1['n']:,}")
print(f"Events (efficiency ≤{FIXED_Q20_THRESHOLD:.1f} mL/mg): {sa1['n_events']}")
print(f"Spearman ρ = {sa1['spearman']:.3f} ({sa1['spearman_ci'][0]:.3f} to {sa1['spearman_ci'][1]:.3f})")
print(f"AUROC (Q20) = {sa1['auroc']:.3f} ({sa1['auroc_ci'][0]:.3f} to {sa1['auroc_ci'][1]:.3f})")

delta_rho_ckd = sa1['spearman'] - primary_24h['spearman_rho']
delta_auroc_ckd = sa1['auroc'] - primary_24h['auroc_q20']
print(f"\nChange from primary: Δρ = {delta_rho_ckd:+.3f}, ΔAUROC = {delta_auroc_ckd:+.3f}")

print("""
NOTE:
  Creatinine is a BAN-ADHF component and also a determinant of diuretic response.
  Excluding advanced CKD reduces score variability (range restriction), which can attenuate correlation.
""")

#------------------------------------------------------------------------------
# 2) EXCLUDE EXTREME OUTLIERS
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("2. EXCLUDE EXTREME EFFICIENCY OUTLIERS (>99th percentile)")
print("="*70)

p99 = df_base['diuretic_efficiency_24h'].quantile(0.99)
df_no_outliers = df_base[df_base['diuretic_efficiency_24h'] <= p99].copy()
sa2 = sensitivity_run(df_no_outliers, "Exclude outliers >P99", FIXED_Q20_THRESHOLD)

print(f"\n99th percentile threshold: {p99:.1f} mL/mg")
print(f"Excluded: {len(df_base) - len(df_no_outliers)} extreme outliers")
print(f"Remaining: N = {sa2['n']:,}")
print(f"Events (efficiency ≤{FIXED_Q20_THRESHOLD:.1f} mL/mg): {sa2['n_events']}")
print(f"Spearman ρ = {sa2['spearman']:.3f} ({sa2['spearman_ci'][0]:.3f} to {sa2['spearman_ci'][1]:.3f})")
print(f"AUROC (Q20) = {sa2['auroc']:.3f} ({sa2['auroc_ci'][0]:.3f} to {sa2['auroc_ci'][1]:.3f})")

#------------------------------------------------------------------------------
# 3) EXCLUDE CARDIOGENIC SHOCK
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("3. EXCLUDE CARDIOGENIC SHOCK PATIENTS")
print("="*70)

df_no_cs = df_base[df_base['cardiogenic_shock'] == 0].copy()
sa3 = sensitivity_run(df_no_cs, "Exclude cardiogenic shock", FIXED_Q20_THRESHOLD)

print(f"\nExcluded: {len(df_base) - len(df_no_cs)} patients with cardiogenic shock")
print(f"Remaining: N = {sa3['n']:,}")
print(f"Events (efficiency ≤{FIXED_Q20_THRESHOLD:.1f} mL/mg): {sa3['n_events']}")
print(f"Spearman ρ = {sa3['spearman']:.3f} ({sa3['spearman_ci'][0]:.3f} to {sa3['spearman_ci'][1]:.3f})")
print(f"AUROC (Q20) = {sa3['auroc']:.3f} ({sa3['auroc_ci'][0]:.3f} to {sa3['auroc_ci'][1]:.3f})")

#------------------------------------------------------------------------------
# 4) RESTRICT TO COMPLETE LVEF DATA
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("4. RESTRICT TO PATIENTS WITH LVEF DATA")
print("="*70)

df_lvef = df_base[df_base['lvef'].notna()].copy()
sa4 = sensitivity_run(df_lvef, "Complete LVEF data", FIXED_Q20_THRESHOLD)

print(f"\nExcluded: {len(df_base) - len(df_lvef)} patients without LVEF")
print(f"Remaining: N = {sa4['n']:,}")
print(f"Events (efficiency ≤{FIXED_Q20_THRESHOLD:.1f} mL/mg): {sa4['n_events']}")
print(f"Spearman ρ = {sa4['spearman']:.3f} ({sa4['spearman_ci'][0]:.3f} to {sa4['spearman_ci'][1]:.3f})")
print(f"AUROC (Q20) = {sa4['auroc']:.3f} ({sa4['auroc_ci'][0]:.3f} to {sa4['auroc_ci'][1]:.3f})")

#------------------------------------------------------------------------------
# 5) ALTERNATIVE EFFICIENCY THRESHOLDS
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("5. ALTERNATIVE EFFICIENCY THRESHOLDS FOR AUROC")
print("="*70)

print(f"\n{'Percentile':<10} {'Definition':<18} {'Threshold':<14} {'Events':<8} {'AUROC':<10} {'95% CI':<18}")
print("-"*85)

threshold_results = []
for percentile, label in [
    (0.10, "Lowest 10%"),
    (0.15, "Lowest 15%"),
    (0.20, "Lowest 20% (Q20)"),
    (0.25, "Lowest 25% (Q25)"),
    (0.30, "Lowest 30%"),
    (0.33, "Lowest 33%")
]:
    thr = df_base['diuretic_efficiency_24h'].quantile(percentile)
    y = (df_base['diuretic_efficiency_24h'] <= thr).astype(int)
    events = int(y.sum())

    auc, auc_ci = bootstrap_auroc_ci(y, df_base['ban_adhf_total_score'], n_boot=n_bootstrap)
    threshold_results.append({
        'percentile': percentile,
        'label': label,
        'threshold': float(thr),
        'n_events': events,
        'auroc': auc,
        'auroc_ci': auc_ci
    })

    marker = "  <- primary" if percentile == 0.20 else ""
    print(f"{int(percentile*100):<10} {label:<18} {thr:<14.1f} {events:<8} {auc:<10.3f} ({auc_ci[0]:.3f}–{auc_ci[1]:.3f}){marker}")

aucs = [r['auroc'] for r in threshold_results if r['auroc'] is not None]
auroc_threshold_range = max(aucs) - min(aucs)
print(f"\nAUROC range across thresholds: {min(aucs):.3f} to {max(aucs):.3f} (spread: {auroc_threshold_range:.3f})")

#------------------------------------------------------------------------------
# 6) ALTERNATIVE DIURETIC RESISTANCE DEFINITIONS
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("6. ALTERNATIVE DIURETIC RESISTANCE DEFINITIONS")
print("="*70)

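# Require recorded 24h urine output: a missing value would otherwise compare as
# False below and be silently counted as "not resistant"
df_dr_base = df[(df['icu_stay_ge_24h'] == 1) &
                (df['urine_output_24h_ml'].notna())].copy()
print(f"\nCohort: ICU ≥24h with 24h urine output recorded (N = {len(df_dr_base):,})")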

print(f"\n{'Definition':<20} {'Resistant':<10} {'Prev':<8} {'AUROC':<10} {'95% CI':<18}")
print("-"*75)

dr_results = []
for thr_uop in [2000, 2500, 3000, 3500, 4000]:
    y = (df_dr_base['urine_output_24h_ml'] <= thr_uop).astype(int)
    resistant = int(y.sum())
    prev = 100 * resistant / len(df_dr_base)

    auc, auc_ci = bootstrap_auroc_ci(y, df_dr_base['ban_adhf_total_score'], n_boot=n_bootstrap)
    dr_results.append({
        'threshold_ml': thr_uop,
        'n_resistant': resistant,
        'prevalence_pct': float(prev),
        'auroc': auc,
        'auroc_ci': auc_ci
    })

    marker = "  <- primary" if thr_uop == 3000 else ""
    definition = f"UOP ≤{thr_uop} mL"
    print(f"{definition:<20} {resistant:<10} {prev:<8.1f} {auc:<10.3f} ({auc_ci[0]:.3f}–{auc_ci[1]:.3f}){marker}")

dr_aucs = [r['auroc'] for r in dr_results if r['auroc'] is not None]
dr_auroc_range = max(dr_aucs) - min(dr_aucs)
print(f"\nBinary DR AUROC range across thresholds: {min(dr_aucs):.3f} to {max(dr_aucs):.3f} (spread: {dr_auroc_range:.3f})")

#------------------------------------------------------------------------------
# SUMMARY TABLE
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SENSITIVITY ANALYSIS SUMMARY TABLE")
print("="*70)

def fmt_ci(ci):
    if ci[0] is None:
        return "NA"
    return f"({ci[0]:.3f}–{ci[1]:.3f})"

rows = [
    ("Primary analysis", primary_24h['n'], primary_24h['spearman_rho'], primary_24h['spearman_ci'], primary_24h['auroc_q20'], primary_24h['auroc_q20_ci']),
    ("Exclude advanced CKD*", sa1['n'], sa1['spearman'], sa1['spearman_ci'], sa1['auroc'], sa1['auroc_ci']),
    ("Exclude outliers (>P99)", sa2['n'], sa2['spearman'], sa2['spearman_ci'], sa2['auroc'], sa2['auroc_ci']),
    ("Exclude cardiogenic shock", sa3['n'], sa3['spearman'], sa3['spearman_ci'], sa3['auroc'], sa3['auroc_ci']),
    ("Complete LVEF data only", sa4['n'], sa4['spearman'], sa4['spearman_ci'], sa4['auroc'], sa4['auroc_ci']),
]

print(f"\n{'Analysis':<28} {'N':<8} {'Spearman ρ':<12} {'95% CI':<18} {'AUROC Q20':<10} {'95% CI':<18} {'Δρ':<8} {'ΔAUC':<8}")
print("-"*115)

for name, n, rho, rho_ci, auc, auc_ci in rows:
    d_rho = rho - primary_24h['spearman_rho']
    d_auc = auc - primary_24h['auroc_q20']
    print(f"{name:<28} {n:<8} {rho:<12.3f} {fmt_ci(rho_ci):<18} {auc:<10.3f} {fmt_ci(auc_ci):<18} {d_rho:<+8.3f} {d_auc:<+8.3f}")

print("""
*Attenuation after excluding advanced CKD is expected from range restriction: creatinine
 is both a score component and an independent predictor of diuretic efficiency.
""")

#------------------------------------------------------------------------------
# ROBUSTNESS ASSESSMENT
#------------------------------------------------------------------------------
robust_rhos = [sa2['spearman'], sa3['spearman'], sa4['spearman']]
robust_aucs = [sa2['auroc'], sa3['auroc'], sa4['auroc']]

rho_spread = max(robust_rhos) - min(robust_rhos)
auc_spread = max(robust_aucs) - min(robust_aucs)

print("\n" + "="*70)
print("ROBUSTNESS ASSESSMENT")
print("="*70)

print(f"""
Across the outlier, cardiogenic-shock, and LVEF-complete analyses (CKD run excluded):
  Spearman ρ spread: {rho_spread:.3f}
  AUROC Q20 spread:  {auc_spread:.3f}

Threshold stability:
  Efficiency AUROC spread (10% to 33%): {auroc_threshold_range:.3f}
  Binary DR AUROC spread (UOP 2000 to 4000): {dr_auroc_range:.3f}
""")

#------------------------------------------------------------------------------
# STORE RESULTS
#------------------------------------------------------------------------------
results_sensitivity = {
    'fixed_q20_threshold': float(FIXED_Q20_THRESHOLD),
    'fixed_q25_threshold': float(FIXED_Q25_THRESHOLD),
    'primary_24h': primary_24h,
    'exclude_ckd': sa1,
    'exclude_outliers': sa2,
    'exclude_cs': sa3,
    'complete_lvef': sa4,
    'efficiency_thresholds': threshold_results,
    'dr_thresholds': dr_results,
    'robustness_excluding_ckd': {
        'rho_spread': float(rho_spread),
        'auc_spread': float(auc_spread),
        'note': "Excludes CKD run because range restriction can attenuate correlation by design."
    }
}

print("\n" + "="*70)
print("✓ Sensitivity analysis results stored in 'results_sensitivity' dictionary")
print("="*70)
print("\n→ Next: Secondary Outcomes (vasopressor, inotrope, MCS, ventilation, LOS)")
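
The CKD note attributes the attenuated ρ to range restriction. The toy simulation below illustrates that mechanism under loudly stated assumptions: a simulated bivariate-normal "score" and "efficiency" with a true correlation of -0.5 (hypothetical data, not MIMIC-IV), from which the top 20% of scores is removed as a loose analogue of excluding the highest-creatinine stratum. It uses only the standard library; inside the notebook, `scipy.stats.spearmanr` would replace the hand-rolled `spearman` helper.

```python
import math
import random

def rank(values):
    """Average ranks (1-based); tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of 1-based positions i..j
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks

def pearson(a, b):
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(va * vb)

def spearman(a, b):
    # Spearman's rho is Pearson's r computed on the ranks
    return pearson(rank(a), rank(b))

# Simulated score/efficiency pairs with true correlation -0.5 (hypothetical)
rng = random.Random(42)
rho_true = -0.5
score_sim, eff_sim = [], []
for _ in range(5000):
    z1, z2 = rng.gauss(0, 1), rng.gauss(0, 1)
    score_sim.append(z1)
    eff_sim.append(rho_true * z1 + math.sqrt(1 - rho_true ** 2) * z2)

# Drop the top 20% of scores (restricts the range of the predictor)
cutoff = sorted(score_sim)[int(0.8 * len(score_sim))]
kept = [i for i, s in enumerate(score_sim) if s < cutoff]

rho_full = spearman(score_sim, eff_sim)
rho_restricted = spearman([score_sim[i] for i in kept], [eff_sim[i] for i in kept])
print(f"full: {rho_full:.3f}   restricted: {rho_restricted:.3f}")
```

The restricted |ρ| comes out smaller even though the underlying relationship is unchanged, which is why the CKD run is set aside in the robustness spread.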

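The sensitivity runs hold the Q20 cut-off at the primary-cohort value instead of recomputing it within each subset. A re-derived quantile always labels roughly 20% of any subset as events, so the event definition would drift with the exclusion criteria. The sketch below shows this on hypothetical efficiencies; `quantile_linear` is an illustrative stand-in for `Series.quantile`, assuming the same linear interpolation that pandas uses by default.

```python
def quantile_linear(values, q):
    """Quantile with linear interpolation at position q * (n - 1) (pandas/NumPy default)."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (pos - lo) * (s[hi] - s[lo])

# Hypothetical efficiencies (mL/mg) for a primary cohort of 100 patients
primary = list(range(1, 101))
fixed_q20 = quantile_linear(primary, 0.20)

# A sensitivity subset that happens to drop the 10 least efficient patients
subset = [v for v in primary if v > 10]
subset_q20 = quantile_linear(subset, 0.20)

events_fixed = sum(v <= fixed_q20 for v in subset)
events_rederived = sum(v <= subset_q20 for v in subset)
print(f"fixed Q20 = {fixed_q20:.1f}, events = {events_fixed}")            # 20.8, 10
print(f"re-derived Q20 = {subset_q20:.1f}, events = {events_rederived}")  # 28.8, 18
```

With the fixed threshold, 10 of 90 subset patients (11%) meet the primary event definition; re-deriving the cut-off would instead flag 18 (20%), including patients the primary analysis does not classify as events.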
In [None]:
#==========================================================================
# SECTION 10: SECONDARY OUTCOMES
# Cell 15: ICU Interventions and Length of Stay
#==========================================================================
# PURPOSE: Evaluate BAN-ADHF association with exploratory secondary ICU outcomes
# OUTCOMES:
#   1. Vasopressor use
#   2. Inotrope use
#   3. Mechanical circulatory support (MCS)
#   4. Invasive mechanical ventilation
#   5. ICU length of stay
#   6. Hospital length of stay
# OUTPUT:
#   - Binary outcomes: prevalence, AUROC (bootstrap 95% CI), risk-category rates, chi-square p
#   - Continuous outcomes: Spearman rho (bootstrap 95% CI), Kruskal-Wallis across risk categories
#==========================================================================

from scipy.stats import spearmanr, chi2_contingency, kruskal
from sklearn.metrics import roc_auc_score
import numpy as np
import pandas as pd

print("="*70)
print("SECONDARY OUTCOMES ANALYSIS")
print("="*70)
print("\nExploratory outcomes: ICU resource utilization and severity markers.")
print("BAN-ADHF was not designed for these outcomes.")

#------------------------------------------------------------------------------
# ANALYSIS COHORT (Full cohort)
#------------------------------------------------------------------------------
df_sec = df.copy()
print(f"\nAnalysis cohort: N = {len(df_sec):,}")

#------------------------------------------------------------------------------
# BOOTSTRAP SETUP
#------------------------------------------------------------------------------
np.random.seed(42)
n_bootstrap = 1000

#------------------------------------------------------------------------------
# UTILITIES
#------------------------------------------------------------------------------
def format_p(p_val):
    if p_val is None or (isinstance(p_val, float) and np.isnan(p_val)):
        return "NA"
    if p_val < 0.001:
        return "<0.001"
    return f"{p_val:.3f}"

def bootstrap_auroc_ci(y, score, n_boot=1000):
    y = pd.Series(y).reset_index(drop=True)
    score = pd.Series(score).reset_index(drop=True)

    if y.sum() == 0 or y.sum() == len(y):
        return None, (None, None)

    auc = roc_auc_score(y, score)

    boot = []
    for _ in range(n_boot):
        idx = np.random.choice(len(y), size=len(y), replace=True)
        yb = y.iloc[idx]
        sb = score.iloc[idx]
        if yb.sum() > 0 and yb.sum() < len(yb):
            boot.append(roc_auc_score(yb, sb))

    ci = (np.percentile(boot, 2.5), np.percentile(boot, 97.5))
    return float(auc), (float(ci[0]), float(ci[1]))

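# NOTE: unlike the Cell 14 helper of the same name, this version takes two arrays
# (x, y) rather than a DataFrame; running this cell replaces that definition.
def bootstrap_spearman_ci(x, y, n_boot=1000):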
    rho, p = spearmanr(x, y)
    boot = []
    x = pd.Series(x).reset_index(drop=True)
    y = pd.Series(y).reset_index(drop=True)

    for _ in range(n_boot):
        idx = np.random.choice(len(x), size=len(x), replace=True)
        r, _ = spearmanr(x.iloc[idx], y.iloc[idx])
        boot.append(r)

    ci = (np.percentile(boot, 2.5), np.percentile(boot, 97.5))
    return float(rho), (float(ci[0]), float(ci[1])), float(p)

def safe_chi_square(risk_series, outcome_series):
    """
    Returns (chi2, p_value) or (None, None) if test cannot be computed.
    Handles cases where a category has zero count or only one outcome level.
    """
    tab = pd.crosstab(risk_series, outcome_series)

    # Need at least 2 rows and 2 columns for chi-square
    if tab.shape[0] < 2 or tab.shape[1] < 2:
        return None, None

    try:
        chi2, p, dof, expected = chi2_contingency(tab)
        return float(chi2), float(p)
    except Exception:
        return None, None

def analyze_binary_outcome(data, outcome_var, outcome_label):
    """
    Binary outcome analysis:
      - filters to rows with risk_category and ban_adhf_total_score present
      - coerces outcome to 0/1 numeric
      - AUROC with bootstrap CI if feasible
      - rates by risk category
      - chi-square test if feasible
    """
    d = data.copy()

    needed_cols = ['ban_adhf_total_score', 'risk_category', outcome_var]
    for c in needed_cols:
        if c not in d.columns:
            raise KeyError(f"Missing required column: {c}")

    d = d[d['ban_adhf_total_score'].notna() & d['risk_category'].notna()].copy()

    # Coerce outcome to numeric 0/1
    d[outcome_var] = pd.to_numeric(d[outcome_var], errors='coerce')
    d = d[d[outcome_var].notna()].copy()

    # Binarize: any positive value (e.g., a count or duration) is treated as an event
    d[outcome_var] = (d[outcome_var] > 0).astype(int)

    n_total = int(len(d))
    n_events = int(d[outcome_var].sum())
    prevalence = 100 * n_events / n_total if n_total > 0 else 0.0

    res = {
        'outcome': outcome_label,
        'n': n_total,
        'n_events': n_events,
        'prevalence': float(prevalence),
        'auroc': None,
        'auroc_ci': (None, None),
        'rates_by_risk': {},
        'chi2': None,
        'p_value': None,
        'rr_high_vs_low': None
    }

    # Rates by risk category
    for cat in ['Low', 'Moderate', 'High']:
        sub = d[d['risk_category'] == cat]
        n = int(len(sub))
        n_pos = int(sub[outcome_var].sum()) if n > 0 else 0
        rate = 100 * n_pos / n if n > 0 else np.nan
        res['rates_by_risk'][cat] = {'n': n, 'n_events': n_pos, 'rate': float(rate) if not np.isnan(rate) else np.nan}

    # AUROC if feasible
    if n_total >= 50 and n_events > 10 and n_events < (n_total - 10):
        auc, auc_ci = bootstrap_auroc_ci(d[outcome_var], d['ban_adhf_total_score'], n_boot=n_bootstrap)
        res['auroc'] = auc
        res['auroc_ci'] = auc_ci

    # Chi-square if feasible
    chi2, p = safe_chi_square(d['risk_category'], d[outcome_var])
    res['chi2'] = chi2
    res['p_value'] = p

    # Risk ratio (High vs Low)
    low_rate = res['rates_by_risk']['Low']['rate']
    high_rate = res['rates_by_risk']['High']['rate']
    if low_rate is not None and not np.isnan(low_rate) and low_rate > 0 and high_rate is not None and not np.isnan(high_rate):
        res['rr_high_vs_low'] = float(high_rate / low_rate)

    return res

#------------------------------------------------------------------------------
# 1. VASOPRESSOR USE
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("1. VASOPRESSOR USE")
print("="*70)

vaso_results = analyze_binary_outcome(df_sec, 'vasopressor_use', 'Vasopressor use')

print(f"\nPrevalence: {vaso_results['n_events']}/{vaso_results['n']} ({vaso_results['prevalence']:.1f}%)")
if vaso_results['auroc'] is not None:
    print(f"AUROC: {vaso_results['auroc']:.3f} ({vaso_results['auroc_ci'][0]:.3f} to {vaso_results['auroc_ci'][1]:.3f})")
else:
    print("AUROC: N/A (insufficient events or no variation)")

print("\nRates by Risk Category:")
for cat in ['Low', 'Moderate', 'High']:
    r = vaso_results['rates_by_risk'][cat]
    rate_str = "NA" if np.isnan(r['rate']) else f"{r['rate']:.1f}%"
    print(f"  {cat}: {r['n_events']}/{r['n']} ({rate_str})")

print(f"\nChi-square: {vaso_results['chi2']:.1f}" if vaso_results['chi2'] is not None else "\nChi-square: NA")
print(f"p = {format_p(vaso_results['p_value'])}")
if vaso_results['rr_high_vs_low'] is not None:
    print(f"Risk ratio (High vs Low): {vaso_results['rr_high_vs_low']:.2f}")

#------------------------------------------------------------------------------
# 2. INOTROPE USE
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("2. INOTROPE USE")
print("="*70)

ino_results = analyze_binary_outcome(df_sec, 'inotrope_use', 'Inotrope use')

print(f"\nPrevalence: {ino_results['n_events']}/{ino_results['n']} ({ino_results['prevalence']:.1f}%)")
if ino_results['auroc'] is not None:
    print(f"AUROC: {ino_results['auroc']:.3f} ({ino_results['auroc_ci'][0]:.3f} to {ino_results['auroc_ci'][1]:.3f})")
else:
    print("AUROC: N/A (insufficient events or no variation)")

print("\nRates by Risk Category:")
for cat in ['Low', 'Moderate', 'High']:
    r = ino_results['rates_by_risk'][cat]
    rate_str = "NA" if np.isnan(r['rate']) else f"{r['rate']:.1f}%"
    print(f"  {cat}: {r['n_events']}/{r['n']} ({rate_str})")

print(f"\nChi-square: {ino_results['chi2']:.1f}" if ino_results['chi2'] is not None else "\nChi-square: NA")
print(f"p = {format_p(ino_results['p_value'])}")
if ino_results['rr_high_vs_low'] is not None:
    print(f"Risk ratio (High vs Low): {ino_results['rr_high_vs_low']:.2f}")

#------------------------------------------------------------------------------
# 3. MECHANICAL CIRCULATORY SUPPORT (MCS)
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("3. MECHANICAL CIRCULATORY SUPPORT (MCS)")
print("="*70)

mcs_results = analyze_binary_outcome(df_sec, 'mcs_use', 'MCS use')

print(f"\nPrevalence: {mcs_results['n_events']}/{mcs_results['n']} ({mcs_results['prevalence']:.1f}%)")
if mcs_results['auroc'] is not None:
    print(f"AUROC: {mcs_results['auroc']:.3f} ({mcs_results['auroc_ci'][0]:.3f} to {mcs_results['auroc_ci'][1]:.3f})")
else:
    print("AUROC: N/A (insufficient events or no variation)")

print("\nRates by Risk Category:")
for cat in ['Low', 'Moderate', 'High']:
    r = mcs_results['rates_by_risk'][cat]
    rate_str = "NA" if np.isnan(r['rate']) else f"{r['rate']:.1f}%"
    print(f"  {cat}: {r['n_events']}/{r['n']} ({rate_str})")

print(f"\nChi-square: {mcs_results['chi2']:.1f}" if mcs_results['chi2'] is not None else "\nChi-square: NA")
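print(f"p = {format_p(mcs_results['p_value'])}")
if mcs_results.get('rr_high_vs_low') is not None:
    print(f"Risk ratio (High vs Low): {mcs_results['rr_high_vs_low']:.2f}")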

#------------------------------------------------------------------------------
# 4. INVASIVE MECHANICAL VENTILATION
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("4. INVASIVE MECHANICAL VENTILATION")
print("="*70)

vent_results = analyze_binary_outcome(df_sec, 'invasive_vent', 'Invasive ventilation')

print(f"\nPrevalence: {vent_results['n_events']}/{vent_results['n']} ({vent_results['prevalence']:.1f}%)")
if vent_results['auroc'] is not None:
    print(f"AUROC: {vent_results['auroc']:.3f} ({vent_results['auroc_ci'][0]:.3f} to {vent_results['auroc_ci'][1]:.3f})")
else:
    print("AUROC: N/A (insufficient events or no variation)")

print("\nRates by Risk Category:")
for cat in ['Low', 'Moderate', 'High']:
    r = vent_results['rates_by_risk'][cat]
    rate_str = "NA" if np.isnan(r['rate']) else f"{r['rate']:.1f}%"
    print(f"  {cat}: {r['n_events']}/{r['n']} ({rate_str})")

print(f"\nChi-square: {vent_results['chi2']:.1f}" if vent_results['chi2'] is not None else "\nChi-square: NA")
print(f"p = {format_p(vent_results['p_value'])}")
if vent_results['rr_high_vs_low'] is not None:
    print(f"Risk ratio (High vs Low): {vent_results['rr_high_vs_low']:.2f}")

#------------------------------------------------------------------------------
# 5. ICU LENGTH OF STAY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("5. ICU LENGTH OF STAY")
print("="*70)

df_los = df_sec[df_sec['icu_los_days'].notna() & df_sec['ban_adhf_total_score'].notna() & df_sec['risk_category'].notna()].copy()
rho_icu, rho_icu_ci, p_icu = bootstrap_spearman_ci(df_los['ban_adhf_total_score'], df_los['icu_los_days'], n_boot=n_bootstrap)

print(f"\nN = {len(df_los):,}")
print(f"Spearman ρ = {rho_icu:.3f} ({rho_icu_ci[0]:.3f} to {rho_icu_ci[1]:.3f})")
print(f"p = {format_p(p_icu)}")

print("\nICU LOS by Risk Category:")
print(f"{'Category':<12} {'N':<8} {'Median (IQR)':<18} {'Mean ± SD':<14}")
print("-"*60)

icu_los_by_risk = {}
groups_icu = []
for cat in ['Low', 'Moderate', 'High']:
    s = df_los[df_los['risk_category'] == cat]['icu_los_days'].dropna()
    n = int(len(s))
    median = float(s.median()) if n > 0 else np.nan
    q1 = float(s.quantile(0.25)) if n > 0 else np.nan
    q3 = float(s.quantile(0.75)) if n > 0 else np.nan
    mean = float(s.mean()) if n > 0 else np.nan
    std = float(s.std()) if n > 1 else np.nan
    icu_los_by_risk[cat] = {'n': n, 'median': median, 'iqr': (q1, q3), 'mean': mean, 'std': std}
    groups_icu.append(s.values)

    med_str = "NA" if np.isnan(median) else f"{median:.1f} ({q1:.1f} to {q3:.1f})"
    mean_str = "NA" if np.isnan(mean) else f"{mean:.1f} ± {std:.1f}"
    print(f"{cat:<12} {n:<8} {med_str:<18} {mean_str:<14}")

# Kruskal-Wallis only if at least 2 groups have data
nonempty = [g for g in groups_icu if len(g) > 0]
if len(nonempty) >= 2:
    kw_icu, p_kw_icu = kruskal(*nonempty)
    print(f"\nKruskal-Wallis H = {kw_icu:.1f}, p = {format_p(p_kw_icu)}")
else:
    kw_icu, p_kw_icu = None, None
    print("\nKruskal-Wallis: NA (insufficient group data)")

#------------------------------------------------------------------------------
# 6. HOSPITAL LENGTH OF STAY
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("6. HOSPITAL LENGTH OF STAY")
print("="*70)

df_hlos = df_sec[df_sec['hospital_los_days'].notna() & df_sec['ban_adhf_total_score'].notna() & df_sec['risk_category'].notna()].copy()
rho_hosp, rho_hosp_ci, p_hosp = bootstrap_spearman_ci(df_hlos['ban_adhf_total_score'], df_hlos['hospital_los_days'], n_boot=n_bootstrap)

print(f"\nN = {len(df_hlos):,}")
print(f"Spearman ρ = {rho_hosp:.3f} ({rho_hosp_ci[0]:.3f} to {rho_hosp_ci[1]:.3f})")
print(f"p = {format_p(p_hosp)}")

print("\nHospital LOS by Risk Category:")
print(f"{'Category':<12} {'N':<8} {'Median (IQR)':<18} {'Mean ± SD':<14}")
print("-"*60)

hosp_los_by_risk = {}
groups_hosp = []
for cat in ['Low', 'Moderate', 'High']:
    s = df_hlos[df_hlos['risk_category'] == cat]['hospital_los_days'].dropna()
    n = int(len(s))
    median = float(s.median()) if n > 0 else np.nan
    q1 = float(s.quantile(0.25)) if n > 0 else np.nan
    q3 = float(s.quantile(0.75)) if n > 0 else np.nan
    mean = float(s.mean()) if n > 0 else np.nan
    std = float(s.std()) if n > 1 else np.nan
    hosp_los_by_risk[cat] = {'n': n, 'median': median, 'iqr': (q1, q3), 'mean': mean, 'std': std}
    groups_hosp.append(s.values)

    med_str = "NA" if np.isnan(median) else f"{median:.1f} ({q1:.1f} to {q3:.1f})"
    mean_str = "NA" if np.isnan(mean) else f"{mean:.1f} ± {std:.1f}"
    print(f"{cat:<12} {n:<8} {med_str:<18} {mean_str:<14}")

nonempty = [g for g in groups_hosp if len(g) > 0]
if len(nonempty) >= 2:
    kw_hosp, p_kw_hosp = kruskal(*nonempty)
    print(f"\nKruskal-Wallis H = {kw_hosp:.1f}, p = {format_p(p_kw_hosp)}")
else:
    kw_hosp, p_kw_hosp = None, None
    print("\nKruskal-Wallis: NA (insufficient group data)")

#------------------------------------------------------------------------------
# SUMMARY TABLE: SECONDARY OUTCOMES
#------------------------------------------------------------------------------
print("\n" + "="*70)
print("SUMMARY TABLE: SECONDARY OUTCOMES")
print("="*70)

print("\nBINARY OUTCOMES:")
print(f"{'Outcome':<24} {'Prevalence':<12} {'AUROC (95% CI)':<22} {'p':<8}")
print("-"*70)

binary_outcomes = [
    ('Vasopressor use', vaso_results),
    ('Inotrope use', ino_results),
    ('MCS use', mcs_results),
    ('Invasive ventilation', vent_results)
]

for name, res in binary_outcomes:
    prev_str = f"{res['prevalence']:.1f}%"
    if res['auroc'] is not None:
        auroc_str = f"{res['auroc']:.3f} ({res['auroc_ci'][0]:.3f} to {res['auroc_ci'][1]:.3f})"
    else:
        auroc_str = "NA"
    p_str = format_p(res['p_value'])
    print(f"{name:<24} {prev_str:<12} {auroc_str:<22} {p_str:<8}")

print("\nCONTINUOUS OUTCOMES:")
print(f"{'Outcome':<24} {'N':<8} {'Spearman ρ (95% CI)':<26} {'p':<8}")
print("-"*70)
print(f"{'ICU length of stay':<24} {len(df_los):<8} {rho_icu:.3f} ({rho_icu_ci[0]:.3f} to {rho_icu_ci[1]:.3f})   {format_p(p_icu):<8}")
print(f"{'Hospital length of stay':<24} {len(df_hlos):<8} {rho_hosp:.3f} ({rho_hosp_ci[0]:.3f} to {rho_hosp_ci[1]:.3f})   {format_p(p_hosp):<8}")

#------------------------------------------------------------------------------
# STORE RESULTS
#------------------------------------------------------------------------------
results_secondary = {
    'vasopressor': vaso_results,
    'inotrope': ino_results,
    'mcs': mcs_results,
    'ventilation': vent_results,
    'icu_los': {
        'n': int(len(df_los)),
        'spearman_rho': float(rho_icu),
        'spearman_ci': tuple(rho_icu_ci),
        'p_value': float(p_icu),
        'kruskal_h': None if kw_icu is None else float(kw_icu),
        'kruskal_p': None if p_kw_icu is None else float(p_kw_icu),
        'by_risk': icu_los_by_risk
    },
    'hospital_los': {
        'n': int(len(df_hlos)),
        'spearman_rho': float(rho_hosp),
        'spearman_ci': tuple(rho_hosp_ci),
        'p_value': float(p_hosp),
        'kruskal_h': None if kw_hosp is None else float(kw_hosp),
        'kruskal_p': None if p_kw_hosp is None else float(p_kw_hosp),
        'by_risk': hosp_los_by_risk
    }
}

print("\n" + "="*70)
print("✓ Secondary outcomes stored in 'results_secondary' dictionary")
print("="*70)
print("\n→ Next: Tables and Figures")

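Each binary secondary outcome above is summarised in three ways: event rates by risk category, a chi-square test across categories, and a High-vs-Low risk ratio. `analyze_binary_outcome` is defined earlier in the notebook and its internals are not shown here. The sketch below is a hypothetical, self-contained re-implementation of those three summaries on toy data. It assumes a 0/1 outcome column and a `risk_category` column with levels Low/Moderate/High; `summarize_binary_by_risk` is an illustrative name, not part of the notebook.

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency


def summarize_binary_by_risk(df, outcome_col, cats=("Low", "Moderate", "High")):
    """Hypothetical helper: rates by category, chi-square p, High-vs-Low risk ratio."""
    d = df[["risk_category", outcome_col]].dropna()
    rates = {}
    for cat in cats:
        s = d.loc[d["risk_category"] == cat, outcome_col]
        n, k = int(len(s)), int(s.sum())
        rates[cat] = {"n": n, "n_events": k, "rate": 100 * k / n if n else np.nan}

    # Chi-square needs at least a 2x2 table (two categories, both outcome levels)
    tab = pd.crosstab(d["risk_category"], d[outcome_col])
    p = None
    if tab.shape[0] >= 2 and tab.shape[1] >= 2:
        _, p, _, _ = chi2_contingency(tab)

    # Risk ratio is undefined when the Low-risk group has no events
    low, high = rates[cats[0]], rates[cats[-1]]
    rr = high["rate"] / low["rate"] if low["n_events"] > 0 and high["n"] > 0 else None
    return rates, p, rr


# Toy data: event risk rises from 10% (Low) to 40% (High)
toy = pd.DataFrame({
    "risk_category": ["Low"] * 100 + ["Moderate"] * 100 + ["High"] * 100,
    "event": [1] * 10 + [0] * 90 + [1] * 25 + [0] * 75 + [1] * 40 + [0] * 60,
})
rates, p, rr = summarize_binary_by_risk(toy, "event")
print({c: round(r["rate"], 1) for c, r in rates.items()}, round(rr, 2))
# → {'Low': 10.0, 'Moderate': 25.0, 'High': 40.0} 4.0
```

Guarding the risk ratio matters for rare outcomes such as MCS, where the Low-risk group may have zero events.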

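ICU and hospital length of stay are reported as a Spearman ρ with a 95% CI from `bootstrap_spearman_ci`, which is defined earlier in the notebook. The sketch below illustrates the percentile bootstrap and may not match that helper exactly. It resamples patients with replacement, recomputes ρ on each resample, and takes the 2.5th and 97.5th percentiles. `spearman_percentile_ci` and the simulated data are hypothetical.

```python
import numpy as np
from scipy.stats import spearmanr


def spearman_percentile_ci(x, y, n_boot=1000, alpha=0.05, seed=42):
    """Hypothetical percentile-bootstrap CI for Spearman's rho."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    rng = np.random.default_rng(seed)
    rho, p = spearmanr(x, y)
    boot = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, len(x), size=len(x))  # resample patients with replacement
        boot[b] = spearmanr(x[idx], y[idx])[0]
    # nanpercentile skips any resample where rho is undefined
    lo, hi = np.nanpercentile(boot, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return float(rho), (float(lo), float(hi)), float(p)


# Simulated integer scores (0-20) and an outcome that falls as the score rises
rng = np.random.default_rng(0)
score = rng.integers(0, 21, size=500)
outcome = 50 - 2 * score + rng.normal(0, 10, size=500)
rho, (lo, hi), p = spearman_percentile_ci(score, outcome)
print(f"rho = {rho:.3f} (95% CI {lo:.3f} to {hi:.3f}), p = {p:.1e}")
```

The resampling unit is the patient (row), so each bootstrap sample keeps every score paired with its own outcome.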
In [None]:
#==========================================================================
# SECTION 12: MAIN MANUSCRIPT TABLES
# Cell 16: Generate Tables 1, 2, and 3
#==========================================================================

from tableone import TableOne
import pandas as pd
import numpy as np
import warnings
from scipy.stats import kruskal, chi2_contingency
warnings.filterwarnings('ignore')

print("="*70)
print("MAIN MANUSCRIPT TABLES")
print("="*70)

# Small helpers
def safe_chi_square(df_in, group_col, outcome_col):
    tab = pd.crosstab(df_in[group_col], df_in[outcome_col])
    if tab.shape[0] < 2 or tab.shape[1] < 2:
        return None
    try:
        chi2, p, dof, exp = chi2_contingency(tab)
        return float(p)
    except Exception:
        return None

def format_p(p):
    if p is None or (isinstance(p, float) and np.isnan(p)):
        return "NA"
    if p < 0.001:
        return "<0.001"
    return f"{p:.3f}"

#==========================================================================
# TABLE 1: BASELINE CHARACTERISTICS BY BAN-ADHF RISK CATEGORY
#==========================================================================

print("\n" + "="*70)
print("TABLE 1: Baseline Characteristics by BAN-ADHF Risk Category")
print("="*70)

df_table1 = df.copy()

# Create male_sex variable
if 'gender' in df_table1.columns:
    if df_table1['gender'].dtype == 'object':
        df_table1['male_sex'] = (df_table1['gender'].astype(str).str.upper().str.strip() == 'M').astype(int)
    else:
        df_table1['male_sex'] = pd.to_numeric(df_table1['gender'], errors='coerce')
else:
    # gender column absent: add an all-missing placeholder so 'male_sex' can still be referenced
    df_table1['male_sex'] = np.nan

# Risk category as string
df_table1['risk_category'] = df_table1['risk_category'].astype(str)

# Convert binary variables to Yes/No for TableOne display
binary_vars_to_fix = [
    'male_sex', 'hx_atrial_fibrillation', 'hx_hypertension',
    'prior_hf_hospitalization_12mo', 'hx_diabetes', 'hx_renal_disease',
    'hx_myocardial_infarction', 'hx_stroke', 'hx_copd',
    'cardiogenic_shock', 'invasive_vent'
]

for var in binary_vars_to_fix:
    if var in df_table1.columns:
        df_table1[var] = pd.to_numeric(df_table1[var], errors='coerce')
        df_table1[var] = df_table1[var].map({1: 'Yes', 0: 'No'})

# Define variables
all_columns = [
    'age', 'male_sex', 'ban_adhf_total_score',
    'creatinine', 'bun', 'ntprobnp', 'dbp',
    'total_furosemide_equivalent_mg',
    'hx_atrial_fibrillation', 'hx_hypertension', 'prior_hf_hospitalization_12mo',
    'lvef', 'hf_phenotype',
    'hx_diabetes', 'hx_renal_disease', 'hx_myocardial_infarction',
    'hx_stroke', 'hx_copd', 'cci_score',
    'cardiogenic_shock', 'invasive_vent'
]
all_columns = [v for v in all_columns if v in df_table1.columns]

categorical_vars = [
    'male_sex', 'hx_atrial_fibrillation', 'hx_hypertension',
    'prior_hf_hospitalization_12mo', 'hf_phenotype',
    'hx_diabetes', 'hx_renal_disease', 'hx_myocardial_infarction',
    'hx_stroke', 'hx_copd', 'cardiogenic_shock', 'invasive_vent'
]
categorical_vars = [v for v in categorical_vars if v in df_table1.columns]

nonnormal_vars = ['ban_adhf_total_score', 'creatinine', 'bun', 'ntprobnp',
                  'total_furosemide_equivalent_mg', 'cci_score']
nonnormal_vars = [v for v in nonnormal_vars if v in df_table1.columns]

labels = {
    'age': 'Age, years',
    'male_sex': 'Male sex',
    'ban_adhf_total_score': 'BAN-ADHF score',
    'creatinine': 'Creatinine, mg/dL',
    'bun': 'BUN, mg/dL',
    'ntprobnp': 'NT-proBNP, pg/mL',
    'dbp': 'Diastolic BP, mmHg',
    'total_furosemide_equivalent_mg': 'Home diuretic dose, mg/day +',
    'hx_atrial_fibrillation': 'Atrial fibrillation',
    'hx_hypertension': 'Hypertension',
    'prior_hf_hospitalization_12mo': 'Prior HF hospitalization (12 mo)',
    'lvef': 'LVEF, % ++',
    'hf_phenotype': 'HF phenotype ++',
    'hx_diabetes': 'Diabetes mellitus',
    'hx_renal_disease': 'Chronic kidney disease',
    'hx_myocardial_infarction': 'Prior myocardial infarction',
    'hx_stroke': 'Prior stroke',
    'hx_copd': 'COPD',
    'cci_score': 'Charlson Comorbidity Index',
    'cardiogenic_shock': 'Cardiogenic shock',
    'invasive_vent': 'Invasive mechanical ventilation'
}

order = {
    'risk_category': ['Low', 'Moderate', 'High'],
    'male_sex': ['Yes', 'No'],
    'hx_atrial_fibrillation': ['Yes', 'No'],
    'hx_hypertension': ['Yes', 'No'],
    'prior_hf_hospitalization_12mo': ['Yes', 'No'],
    'hx_diabetes': ['Yes', 'No'],
    'hx_renal_disease': ['Yes', 'No'],
    'hx_myocardial_infarction': ['Yes', 'No'],
    'hx_stroke': ['Yes', 'No'],
    'hx_copd': ['Yes', 'No'],
    'cardiogenic_shock': ['Yes', 'No'],
    'invasive_vent': ['Yes', 'No']
}

limit = {k: 1 for k in order.keys() if k != 'risk_category'}

table1 = TableOne(
    df_table1,
    columns=all_columns,
    categorical=categorical_vars,
    nonnormal=nonnormal_vars,
    groupby='risk_category',
    pval=True,
    rename=labels,
    missing=False,
    overall=True,
    order=order,
    limit=limit,
    decimals=1
)

print("\n")
print(table1.tabulate(tablefmt="simple"))

print("\n" + "-"*70)
print("+ Home diuretic dose as oral furosemide equivalents")
if 'lvef' in df_table1.columns:
    print("++ LVEF available in", int(df_table1['lvef'].notna().sum()), "patients")

table1.to_csv('/content/Table1_baseline_by_risk.csv')
table1.to_excel('/content/Table1_baseline_by_risk.xlsx')
print("\n✓ Table 1 saved")

# Verification gradients (uses original df, not the Yes/No transformed df_table1)
print("\n" + "-"*70)
print("VERIFICATION: Comorbidity Gradients (should increase with higher risk)")
print("-"*70)

df_verify = df.copy()
for col in ['hx_diabetes', 'hx_renal_disease', 'cardiogenic_shock', 'prior_hf_hospitalization_12mo']:
    if col in df_verify.columns:
        df_verify[col] = pd.to_numeric(df_verify[col], errors='coerce')

for cat in ['Low', 'Moderate', 'High']:
    subset = df_verify[df_verify['risk_category'] == cat]
    n = len(subset)
    dm_pct = 100 * subset['hx_diabetes'].mean() if 'hx_diabetes' in subset.columns else np.nan
    ckd_pct = 100 * subset['hx_renal_disease'].mean() if 'hx_renal_disease' in subset.columns else np.nan
    cs_pct = 100 * subset['cardiogenic_shock'].mean() if 'cardiogenic_shock' in subset.columns else np.nan
    hf_hosp = 100 * subset['prior_hf_hospitalization_12mo'].mean() if 'prior_hf_hospitalization_12mo' in subset.columns else np.nan
    print(f"  {cat}: DM={dm_pct:.1f}%, CKD={ckd_pct:.1f}%, CS={cs_pct:.1f}%, Prior HF hosp={hf_hosp:.1f}%")

#==========================================================================
# CALCULATE EFFICIENCY BY RISK FOR TABLE 2
#==========================================================================

print("\n" + "-"*70)
print("Calculating efficiency distributions by risk category")
print("-"*70)

df_24h = df[(df['icu_stay_ge_24h'] == 1) &
            (df['diuretic_efficiency_24h'].notna()) &
            (df['diuretic_efficiency_24h'] > 0)].copy()

eff_24h_by_risk = {}
for cat in ['Low', 'Moderate', 'High']:
    subset = df_24h[df_24h['risk_category'] == cat]['diuretic_efficiency_24h']
    eff_24h_by_risk[cat] = {
        'n': int(len(subset)),
        'median': float(subset.median()),
        'q1': float(subset.quantile(0.25)),
        'q3': float(subset.quantile(0.75))
    }
    print(f"  24h {cat}: N={eff_24h_by_risk[cat]['n']}, Median={eff_24h_by_risk[cat]['median']:.1f} [{eff_24h_by_risk[cat]['q1']:.1f} to {eff_24h_by_risk[cat]['q3']:.1f}]")

df_72h = df[(df['icu_stay_ge_72h'] == 1) &
            (df['diuretic_efficiency_72h'].notna()) &
            (df['diuretic_efficiency_72h'] > 0)].copy()

eff_72h_by_risk = {}
for cat in ['Low', 'Moderate', 'High']:
    subset = df_72h[df_72h['risk_category'] == cat]['diuretic_efficiency_72h']
    eff_72h_by_risk[cat] = {
        'n': int(len(subset)),
        'median': float(subset.median()),
        'q1': float(subset.quantile(0.25)),
        'q3': float(subset.quantile(0.75))
    }
    print(f"  72h {cat}: N={eff_72h_by_risk[cat]['n']}, Median={eff_72h_by_risk[cat]['median']:.1f} [{eff_72h_by_risk[cat]['q1']:.1f} to {eff_72h_by_risk[cat]['q3']:.1f}]")

# Kruskal-Wallis p across risk categories for efficiency rows
groups_24h = [df_24h[df_24h['risk_category'] == cat]['diuretic_efficiency_24h'].values for cat in ['Low', 'Moderate', 'High']]
groups_24h = [g for g in groups_24h if len(g) > 0]
p_kw_24h = kruskal(*groups_24h).pvalue if len(groups_24h) >= 2 else None

groups_72h = [df_72h[df_72h['risk_category'] == cat]['diuretic_efficiency_72h'].values for cat in ['Low', 'Moderate', 'High']]
groups_72h = [g for g in groups_72h if len(g) > 0]
p_kw_72h = kruskal(*groups_72h).pvalue if len(groups_72h) >= 2 else None

#==========================================================================
# TABLE 2: PRIMARY OUTCOMES AND DISCRIMINATION SUMMARY
#==========================================================================

print("\n" + "="*70)
print("TABLE 2: Primary Outcomes and Discrimination Summary")
print("="*70)

table2_rows = []

# 24-Hour Diuretic Efficiency
table2_rows.append(['24-Hour Diuretic Efficiency', 'N', f"{results_24h['n']:,}", 'NA', 'NA'])
table2_rows.append(['', 'Spearman rho', f"{results_24h['spearman_rho']:.3f}",
                    f"({results_24h['spearman_ci'][0]:.3f}, {results_24h['spearman_ci'][1]:.3f})", "<0.001"])
table2_rows.append(['', 'Pearson r', f"{results_24h['pearson_r']:.3f}",
                    f"({results_24h['pearson_ci'][0]:.3f}, {results_24h['pearson_ci'][1]:.3f})", "<0.001"])
table2_rows.append(['', 'C-index', f"{results_24h['c_index']:.3f}",
                    f"({results_24h['c_index_ci'][0]:.3f}, {results_24h['c_index_ci'][1]:.3f})", "NA"])
table2_rows.append(['', 'AUROC (Q20)', f"{results_24h['auroc_quintile']:.3f}",
                    f"({results_24h['auroc_quintile_ci'][0]:.3f}, {results_24h['auroc_quintile_ci'][1]:.3f})", "NA"])
table2_rows.append(['', 'AUROC (Q25)', f"{results_24h['auroc_quartile']:.3f}",
                    f"({results_24h['auroc_quartile_ci'][0]:.3f}, {results_24h['auroc_quartile_ci'][1]:.3f})", "NA"])
table2_rows.append(['', 'Efficiency by risk (median [Q1 to Q3])', '', '', f"{format_p(p_kw_24h)} *"])

for cat in ['Low', 'Moderate', 'High']:
    e = eff_24h_by_risk[cat]
    table2_rows.append(['', f'  {cat}, mL/mg', f"{e['median']:.1f}", f"[{e['q1']:.1f} to {e['q3']:.1f}]", ''])

# 72-Hour Diuretic Efficiency
table2_rows.append(['72-Hour Diuretic Efficiency', 'N', f"{results_72h['n']:,}", 'NA', 'NA'])
table2_rows.append(['', 'Spearman rho', f"{results_72h['spearman_rho']:.3f}",
                    f"({results_72h['spearman_ci'][0]:.3f}, {results_72h['spearman_ci'][1]:.3f})", "<0.001"])
table2_rows.append(['', 'AUROC (Q20)', f"{results_72h['auroc_quintile']:.3f}",
                    f"({results_72h['auroc_quintile_ci'][0]:.3f}, {results_72h['auroc_quintile_ci'][1]:.3f})", "NA"])
table2_rows.append(['', 'Efficiency by risk (median [Q1 to Q3])', '', '', f"{format_p(p_kw_72h)} *"])

for cat in ['Low', 'Moderate', 'High']:
    e = eff_72h_by_risk[cat]
    table2_rows.append(['', f'  {cat}, mL/mg', f"{e['median']:.1f}", f"[{e['q1']:.1f} to {e['q3']:.1f}]", ''])

# Diuretic Resistance
table2_rows.append(['Diuretic Resistance +', 'N', f"{results_dr['n']:,}", 'NA', 'NA'])
table2_rows.append(['', 'Prevalence', f"{results_dr['prevalence']:.1f}%", 'NA', 'NA'])
table2_rows.append(['', 'AUROC', f"{results_dr['auroc']:.3f}",
                    f"({results_dr['auroc_ci'][0]:.3f}, {results_dr['auroc_ci'][1]:.3f})", "NA"])

p_dr_risk = None
if 'dr_by_risk' in results_dr and 'risk_category' in df.columns and 'diuretic_resistance' in df.columns:
    df_tmp = df[['risk_category', 'diuretic_resistance']].dropna().copy()
    df_tmp['diuretic_resistance'] = pd.to_numeric(df_tmp['diuretic_resistance'], errors='coerce')
    df_tmp = df_tmp[df_tmp['diuretic_resistance'].notna()]
    df_tmp['diuretic_resistance'] = (df_tmp['diuretic_resistance'] > 0).astype(int)
    p_dr_risk = safe_chi_square(df_tmp, 'risk_category', 'diuretic_resistance')

table2_rows.append(['', 'DR rate by risk', '', '', f"{format_p(p_dr_risk)} *"])

for cat in ['Low', 'Moderate', 'High']:
    rate = results_dr['dr_by_risk'][cat]['rate']
    table2_rows.append(['', f'  {cat}', f"{rate:.1f}%", 'NA', ''])

# Mortality
table2_rows.append(['In-Hospital Mortality ++', 'N', f"{results_mortality['n']:,}", 'NA', 'NA'])
table2_rows.append(['', 'Mortality rate', f"{results_mortality['mortality_rate']:.1f}%", 'NA', 'NA'])
table2_rows.append(['', 'AUROC', f"{results_mortality['auroc']:.3f}",
                    f"({results_mortality['auroc_ci'][0]:.3f}, {results_mortality['auroc_ci'][1]:.3f})", "NA"])
table2_rows.append(['', 'AUROC (CS subgroup)', f"{results_mortality['auroc_cs']:.3f}", 'NA', 'NA'])

p_mort_risk = None
if 'mort_by_risk' in results_mortality and 'risk_category' in df.columns and 'in_hospital_mortality' in df.columns:
    df_tmp = df[['risk_category', 'in_hospital_mortality']].dropna().copy()
    df_tmp['in_hospital_mortality'] = pd.to_numeric(df_tmp['in_hospital_mortality'], errors='coerce')
    df_tmp = df_tmp[df_tmp['in_hospital_mortality'].notna()]
    df_tmp['in_hospital_mortality'] = (df_tmp['in_hospital_mortality'] > 0).astype(int)
    p_mort_risk = safe_chi_square(df_tmp, 'risk_category', 'in_hospital_mortality')

table2_rows.append(['', 'Mortality by risk', '', '', f"{format_p(p_mort_risk)} *"])

for cat in ['Low', 'Moderate', 'High']:
    rate = results_mortality['mort_by_risk'][cat]['rate']
    table2_rows.append(['', f'  {cat}', f"{rate:.1f}%", 'NA', ''])

table2_df = pd.DataFrame(table2_rows, columns=['Outcome', 'Metric', 'Value', '95% CI', 'p-value'])

print("\n")
print(table2_df.to_string(index=False))
print("\n" + "-"*70)
print("* Kruskal-Wallis (efficiency) or Chi-square (binary) across risk categories")
print("+ Urine output <=3,000 mL in first 24 hours")
print("++ Exploratory outcome. BAN-ADHF designed for diuretic efficiency")

table2_df.to_csv('/content/Table2_outcomes_discrimination.csv', index=False)
table2_df.to_excel('/content/Table2_outcomes_discrimination.xlsx', index=False)
print("\n✓ Table 2 saved")

#==========================================================================
# TABLE 3: LITERATURE COMPARISON
#==========================================================================

print("\n" + "="*70)
print("TABLE 3: Literature Comparison")
print("="*70)

delta_pearson = ((abs(results_24h['pearson_r']) - 0.40) / 0.40) * 100
delta_q20 = ((results_24h['auroc_quintile'] - 0.84) / 0.84) * 100
delta_q25 = ((results_24h['auroc_quartile'] - 0.70) / 0.70) * 100
delta_dr = ((results_dr['auroc'] - 0.631) / 0.631) * 100

table3_rows = [
    ['Segar 2024', 'DOSE/ESCAPE', '707', 'NA', 'NA', '0.84', 'NA', 'NA', 'External validation'],
    ['Pandey 2025', 'CLOROTIC', '220', 'NA', 'B=-0.18/pt +', 'NA', '0.70', 'NA', 'Trial validation'],
    ['Mauch 2025', 'Floor patients', '317', 'NA', 'r=-0.40', 'NA', 'NA', '0.631', 'Real-world'],
    ['This Study', 'ICU ADHF', f"{results_24h['n']}", f"{results_24h['spearman_rho']:.3f}",
     f"r={results_24h['pearson_r']:.3f}", f"{results_24h['auroc_quintile']:.3f}",
     f"{results_24h['auroc_quartile']:.3f}", f"{results_dr['auroc']:.3f}", 'ICU validation'],
]

table3_df = pd.DataFrame(
    table3_rows,
    columns=['Study', 'Population', 'N', 'Spearman rho', 'Association', 'AUROC Q20', 'AUROC Q25', 'Binary DR', 'Notes']
)

print("\n")
print(table3_df.to_string(index=False))

print("\n" + "-"*70)
print("COMPARISON VS REFERENCES")
print("-"*70)
print(f"  Pearson r:  This study r={results_24h['pearson_r']:.3f} vs Mauch r=-0.40  -> {delta_pearson:+.1f}%")
print(f"  AUROC Q20:  This study {results_24h['auroc_quintile']:.3f} vs Segar 0.84  -> {delta_q20:+.1f}%")
print(f"  AUROC Q25:  This study {results_24h['auroc_quartile']:.3f} vs Pandey 0.70 -> {delta_q25:+.1f}%")
print(f"  Binary DR:  This study {results_dr['auroc']:.3f} vs Mauch 0.631 -> {delta_dr:+.1f}%")

print("\n" + "-"*70)
print("NOTES")
print("-"*70)
print("+ Pandey reports a beta coefficient (regression slope), not a correlation")
print("- Segar external validation: DOSE/ESCAPE, AUROC Q20 = 0.84")
print("- Mauch: r = -0.40 for 72h diuretic efficiency; binary DR AUROC 0.631 for a 24h urine output goal")
print(f"- This study binary DR: UOP <=3,000 mL in 24h, prevalence {results_dr['prevalence']:.1f}%")
print("- Mauch's binary DR prevalence is much lower; differences in prevalence and case mix limit direct AUROC comparison with a high-prevalence ICU cohort")

table3_df.to_csv('/content/Table3_literature_comparison.csv', index=False)
table3_df.to_excel('/content/Table3_literature_comparison.xlsx', index=False)
print("\n✓ Table 3 saved")

#==========================================================================
# KEY FINDINGS SUMMARY FOR MANUSCRIPT
#==========================================================================

print("\n" + "="*70)
print("KEY FINDINGS SUMMARY FOR MANUSCRIPT")
print("="*70)

print(f"""
PRIMARY OUTCOME: 24h Diuretic Efficiency
  Spearman rho = {results_24h['spearman_rho']:.3f} (95% CI: {results_24h['spearman_ci'][0]:.3f}, {results_24h['spearman_ci'][1]:.3f}), p<0.001
  Pearson r    = {results_24h['pearson_r']:.3f} (95% CI: {results_24h['pearson_ci'][0]:.3f}, {results_24h['pearson_ci'][1]:.3f}), p<0.001
  AUROC (Q20)  = {results_24h['auroc_quintile']:.3f} (95% CI: {results_24h['auroc_quintile_ci'][0]:.3f}, {results_24h['auroc_quintile_ci'][1]:.3f})

RISK STRATIFICATION (median 24h efficiency)
  Low:      {eff_24h_by_risk['Low']['median']:.1f} mL/mg
  Moderate: {eff_24h_by_risk['Moderate']['median']:.1f} mL/mg
  High:     {eff_24h_by_risk['High']['median']:.1f} mL/mg
  Low/High: {eff_24h_by_risk['Low']['median'] / eff_24h_by_risk['High']['median']:.1f}-fold

CONCLUSION
  BAN-ADHF retains continuous discrimination for diuretic efficiency in a critically ill ICU cohort.
  Discrimination for binary diuretic resistance is lower in this high-prevalence ICU setting.
""")

#==========================================================================
# R DATA EXPORTS
#==========================================================================

print("\n" + "="*70)
print("DATA EXPORTS FOR R")
print("="*70)

# Forest plot data
# Accept results_subgroups['forest_data'], a results_subgroups DataFrame, or a
# standalone forest_data DataFrame; globals().get avoids a NameError if the
# subgroup cell was not run
forest_df = None
_subgroups = globals().get('results_subgroups')
if isinstance(_subgroups, dict):
    forest_df = _subgroups.get('forest_data', None)
elif isinstance(_subgroups, pd.DataFrame):
    forest_df = _subgroups

# Fall back to a variable named forest_data in memory
if forest_df is None and 'forest_data' in globals():
    forest_df = globals()['forest_data']

if forest_df is None:
    print("Forest plot export skipped. Could not find forest_data dataframe.")
else:
    forest_df.to_csv('/content/data_for_R_forest_plot.csv', index=False)
    print("✓ data_for_R_forest_plot.csv")

# Outcomes by risk
outcomes_risk = []
for cat in ['Low', 'Moderate', 'High']:
    outcomes_risk.append({
        'risk_category': cat,
        'eff_24h_median': eff_24h_by_risk[cat]['median'],
        'eff_24h_q1': eff_24h_by_risk[cat]['q1'],
        'eff_24h_q3': eff_24h_by_risk[cat]['q3'],
        'eff_72h_median': eff_72h_by_risk[cat]['median'],
        'eff_72h_q1': eff_72h_by_risk[cat]['q1'],
        'eff_72h_q3': eff_72h_by_risk[cat]['q3'],
        'dr_rate': results_dr['dr_by_risk'][cat]['rate'],
        'mortality_rate': results_mortality['mort_by_risk'][cat]['rate']
    })
pd.DataFrame(outcomes_risk).to_csv('/content/data_for_R_outcomes_by_risk.csv', index=False)
print("✓ data_for_R_outcomes_by_risk.csv")

# Raw efficiency data for box plots
df_eff_r = df[['hadm_id', 'ban_adhf_total_score', 'risk_category',
               'diuretic_efficiency_24h', 'diuretic_efficiency_72h']].copy()
df_eff_r.to_csv('/content/data_for_R_efficiency_raw.csv', index=False)
print("✓ data_for_R_efficiency_raw.csv")

# Literature comparison data for plotting
lit_compare = pd.DataFrame({
    'study': ['Segar 2024 (derivation)', 'Segar 2024 (validation)', 'Pandey 2025', 'Mauch 2025', 'This Study'],
    'population': ['ROSE/CARRESS/ATHENA', 'DOSE/ESCAPE', 'CLOROTIC', 'Floor', 'ICU'],
    'n': [794, 707, 220, 317, results_24h['n']],
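    'pearson_r': [np.nan, np.nan, np.nan, -0.40, results_24h['pearson_r']],
    'beta_per_point': [np.nan, np.nan, -0.18, np.nan, np.nan],  # Pandey reports a regression slope, not r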
    'auroc_q20': [0.87, 0.84, np.nan, np.nan, results_24h['auroc_quintile']],
    'auroc_q25': [np.nan, np.nan, 0.70, np.nan, results_24h['auroc_quartile']],
    'auroc_dr': [np.nan, np.nan, np.nan, 0.631, results_dr['auroc']]
})
lit_compare.to_csv('/content/data_for_R_literature_comparison.csv', index=False)
print("✓ data_for_R_literature_comparison.csv")

print("\n" + "="*70)
print("✓ ALL MAIN TABLES COMPLETE")
print("="*70)
print("""
Files generated:
  - Table1_baseline_by_risk.csv/.xlsx
  - Table2_outcomes_discrimination.csv/.xlsx
  - Table3_literature_comparison.csv/.xlsx
  - data_for_R_forest_plot.csv (if available)
  - data_for_R_outcomes_by_risk.csv
  - data_for_R_efficiency_raw.csv
  - data_for_R_literature_comparison.csv
""")
print("\n-> Next: Supplementary Tables (Cell 17)")

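Table 3 expresses each metric as a relative difference from the published reference value, (this study − reference) / reference × 100. The headline 4.2-fold gradient is the ratio of Low-risk to High-risk median 24h efficiency. The worked arithmetic below uses values quoted in the Study Overview: AUROC 0.780 versus Segar's 0.84, and medians of 47.4 and 11.3 mL/mg. `relative_difference_pct` is a hypothetical helper that mirrors the `delta_*` expressions in the cell above.

```python
def relative_difference_pct(value, reference):
    """Percent difference of `value` relative to a published `reference`."""
    return (value - reference) / reference * 100


# AUROC for the lowest efficiency quintile: this study vs Segar 2024 (0.84)
delta_q20 = relative_difference_pct(0.780, 0.84)
print(f"AUROC Q20: {delta_q20:+.1f}%")  # → AUROC Q20: -7.1%

# Low-risk vs high-risk median 24h efficiency (mL/mg)
fold = 47.4 / 11.3
print(f"Low vs High median efficiency: {fold:.1f}-fold")  # → Low vs High median efficiency: 4.2-fold
```

For the Pearson comparison, the cell above uses `abs(pearson_r)` so that a stronger inverse correlation than Mauch's r = -0.40 appears as a positive difference.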

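The corrected C-index in the next cell counts a pair of patients as concordant when the higher BAN-ADHF score goes with the lower efficiency. On this scale, 1.0 means perfect inverse ordering and 0.5 means no discrimination. When neither variable has ties, this C-index should equal (1 − τ)/2, where τ is Kendall's tau between score and efficiency, which gives a quick check of the direction. The sketch below uses a vectorised pair count on simulated data. `inverse_c_index` is a hypothetical stand-in for `concordance_index_fallback`, not the notebook's function.

```python
import numpy as np
from scipy.stats import kendalltau


def inverse_c_index(score, efficiency):
    """C-index where a higher score should predict LOWER efficiency.

    Pairs tied on efficiency are not permissible; pairs tied on score count 0.5.
    """
    x = np.asarray(score, dtype=float)
    y = np.asarray(efficiency, dtype=float)
    i, j = np.triu_indices(len(x), k=1)  # every unordered pair once
    dx, dy = x[i] - x[j], y[i] - y[j]
    permissible = dy != 0
    if not permissible.any():
        return float("nan")
    concordant = (dx * dy < 0) & permissible  # opposite signs: higher score, lower efficiency
    tied_score = (dx == 0) & permissible
    return float((concordant.sum() + 0.5 * tied_score.sum()) / permissible.sum())


# Continuous simulated data (no ties) with an inverse score-efficiency relation
rng = np.random.default_rng(1)
score = rng.normal(size=300)
efficiency = -score + rng.normal(scale=1.0, size=300)

c = inverse_c_index(score, efficiency)
tau, _ = kendalltau(score, efficiency)
print(f"C-index = {c:.3f}, (1 - tau) / 2 = {(1 - tau) / 2:.3f}")
```

A C-index below 0.5 on real data would suggest the arguments are reversed. That reversal is the error the "CORRECTED" comments in the next cell refer to.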
In [None]:
#==========================================================================
# SECTION 12: MAIN MANUSCRIPT TABLES
# Cell 16: Generate Tables 1, 2, and 3 (CORRECTED C-INDEX)
#==========================================================================

import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
from tableone import TableOne

from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import roc_auc_score

print("="*70)
print("MAIN MANUSCRIPT TABLES")
print("="*70)

# -------------------------------------------------------------------------
# Bootstrap setup (used for Table 2 primary metrics)
# -------------------------------------------------------------------------
np.random.seed(42)
n_bootstrap = 1000

def bootstrap_ci(values, alpha=0.05):
    # Percentile CI; nanpercentile ignores resamples where the statistic is undefined
    lo = np.nanpercentile(values, 100 * (alpha/2))
    hi = np.nanpercentile(values, 100 * (1 - alpha/2))
    return (lo, hi)

def safe_roc_auc(y_true, y_score):
    y_true = np.asarray(y_true)
    if y_true.sum() == 0 or y_true.sum() == len(y_true):
        return None
    return roc_auc_score(y_true, y_score)

def concordance_index_fallback(x, y):
    """
    Harrell C-index for a continuous outcome.
    A higher BAN-ADHF score (x) should map to lower efficiency (y),
    so compute c-index(efficiency, -score) = c-index(y, -x).

    CORRECTED: Arguments now match Cell 9's approach.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    try:
        from lifelines.utils import concordance_index
        return float(concordance_index(y, -x))  # CORRECTED: was (x, -y)
    except Exception:
        pass  # lifelines unavailable or no admissible pairs: use the NumPy count below

    # Vectorised count over all unordered pairs (the nested Python loop was O(n^2)
    # per call, which is prohibitively slow inside a 1,000-sample bootstrap)
    i, j = np.triu_indices(len(x), k=1)
    dy = y[i] - y[j]
    dx = x[i] - x[j]
    permissible = dy != 0  # pairs tied on efficiency are skipped
    if not permissible.any():
        return np.nan

    # Higher score should predict lower efficiency:
    # concordant if (y[i] > y[j] and x[i] < x[j]) or (y[i] < y[j] and x[i] > x[j])
    concordant = np.sum((dx * dy < 0) & permissible)  # CORRECTED: was dy * dx > 0
    ties = np.sum((dx == 0) & permissible)  # ties in score count 0.5

    return float((concordant + 0.5 * ties) / permissible.sum())

def compute_primary_eff_metrics(df_in, score_col, eff_col, q_list=(0.20, 0.25), n_boot=1000):
    dfv = df_in[[score_col, eff_col]].dropna().copy()
    dfv = dfv[dfv[eff_col] > 0].copy()
    n = len(dfv)

    x = dfv[score_col].astype(float).values
    y = dfv[eff_col].astype(float).values

    # Spearman + bootstrap CI
    rho, p_rho = spearmanr(x, y)
    rho_boot = []
    for _ in range(n_boot):
        idx = np.random.choice(n, size=n, replace=True)
        r, _ = spearmanr(x[idx], y[idx])
        rho_boot.append(r)
    rho_ci = bootstrap_ci(rho_boot)

    # Pearson + bootstrap CI
    r_lin, p_lin = pearsonr(x, y)
    r_boot = []
    for _ in range(n_boot):
        idx = np.random.choice(n, size=n, replace=True)
        rr, _ = pearsonr(x[idx], y[idx])
        r_boot.append(rr)
    r_ci = bootstrap_ci(r_boot)

    # C-index + bootstrap CI
    c_index = concordance_index_fallback(x, y)
    c_boot = []
    for _ in range(n_boot):
        idx = np.random.choice(n, size=n, replace=True)
        c_boot.append(concordance_index_fallback(x[idx], y[idx]))
    c_ci = bootstrap_ci(c_boot)

    # AUROC(s) for low efficiency by quantile thresholds
    aurocs = {}
    for q in q_list:
        thr = float(np.quantile(y, q))
        low = (y <= thr).astype(int)
        auc = safe_roc_auc(low, x)

        auc_boot = []
        if auc is not None:
            for _ in range(n_boot):
                idx = np.random.choice(n, size=n, replace=True)
                low_b = low[idx]
                x_b = x[idx]
                auc_b = safe_roc_auc(low_b, x_b)
                if auc_b is not None:
                    auc_boot.append(auc_b)
            auc_ci = bootstrap_ci(auc_boot) if len(auc_boot) > 0 else (np.nan, np.nan)
        else:
            auc_ci = (np.nan, np.nan)

        aurocs[q] = {
            "threshold": thr,
            "n_events": int(low.sum()),
            "auroc": float(auc) if auc is not None else np.nan,
            "auroc_ci": (float(auc_ci[0]), float(auc_ci[1]))
        }

    return {
        "n": int(n),
        "spearman_rho": float(rho),
        "spearman_p": float(p_rho),
        "spearman_ci": (float(rho_ci[0]), float(rho_ci[1])),
        "pearson_r": float(r_lin),
        "pearson_p": float(p_lin),
        "pearson_ci": (float(r_ci[0]), float(r_ci[1])),
        "c_index": float(c_index),
        "c_index_ci": (float(c_ci[0]), float(c_ci[1])),
        "aurocs": aurocs
    }

def format_ci_tuple(ci, decimals=3):
    return f"({ci[0]:.{decimals}f}, {ci[1]:.{decimals}f})"
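# Sketch (an alternative, not what the analysis uses): compute_primary_eff_metrics()
# draws resamples from NumPy's global RNG, so its CIs depend on the seed set above
# and on how many draws earlier calls consumed (the 24h call runs before the 72h
# call). Passing a local Generator per call makes each CI reproducible on its own.
# `_bootstrap_spearman_ci` is a hypothetical helper; the toy data are simulated.
import numpy as np
from scipy.stats import spearmanr

def _bootstrap_spearman_ci(x, y, n_boot=1000, seed=42, alpha=0.05):
    rng = np.random.default_rng(seed)      # independent of np.random.seed()
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    stats = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)   # resample patients with replacement
        stats[b] = spearmanr(x[idx], y[idx])[0]
    lo, hi = np.nanpercentile(stats, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return float(lo), float(hi)

_demo_rng2 = np.random.default_rng(3)
_demo_s = _demo_rng2.integers(0, 20, size=150).astype(float)       # toy scores
_demo_e = 60 - 2.5 * _demo_s + _demo_rng2.normal(0, 12, size=150)  # toy efficiency
_demo_ci_a = _bootstrap_spearman_ci(_demo_s, _demo_e, n_boot=200, seed=1)
_demo_ci_b = _bootstrap_spearman_ci(_demo_s, _demo_e, n_boot=200, seed=1)  # identical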


#==========================================================================
# TABLE 1: BASELINE CHARACTERISTICS BY BAN-ADHF RISK CATEGORY
#==========================================================================

print("\n" + "="*70)
print("TABLE 1: Baseline Characteristics by BAN-ADHF Risk Category")
print("="*70)

df_table1 = df.copy()

# Create male_sex variable robustly
if 'gender' in df_table1.columns:
    if df_table1['gender'].dtype == 'object':
        g = df_table1['gender'].astype(str).str.strip().str.upper()
        df_table1['male_sex'] = g.isin(['M', 'MALE', 'MAN']).astype(int)
    else:
        df_table1['male_sex'] = pd.to_numeric(df_table1['gender'], errors='coerce')
else:
    # Without a gender column, add an all-missing male_sex so TableOne still runs
    df_table1['male_sex'] = np.nan

df_table1['risk_category'] = df_table1['risk_category'].astype(str)

# Convert binary variables to Yes/No for TableOne
binary_vars_to_fix = [
    'male_sex', 'hx_atrial_fibrillation', 'hx_hypertension',
    'prior_hf_hospitalization_12mo', 'hx_diabetes', 'hx_renal_disease',
    'hx_myocardial_infarction', 'hx_stroke', 'hx_copd',
    'cardiogenic_shock', 'invasive_vent'
]

for var in binary_vars_to_fix:
    if var in df_table1.columns:
        df_table1[var] = df_table1[var].map({1: 'Yes', 0: 'No', True: 'Yes', False: 'No'})

all_columns = [
    'age', 'male_sex', 'ban_adhf_total_score',
    'creatinine', 'bun', 'ntprobnp', 'dbp',
    'total_furosemide_equivalent_mg',
    'hx_atrial_fibrillation', 'hx_hypertension', 'prior_hf_hospitalization_12mo',
    'lvef', 'hf_phenotype',
    'hx_diabetes', 'hx_renal_disease', 'hx_myocardial_infarction',
    'hx_stroke', 'hx_copd', 'cci_score',
    'cardiogenic_shock', 'invasive_vent'
]
all_columns = [v for v in all_columns if v in df_table1.columns]

categorical_vars = [
    'male_sex', 'hx_atrial_fibrillation', 'hx_hypertension',
    'prior_hf_hospitalization_12mo', 'hf_phenotype',
    'hx_diabetes', 'hx_renal_disease', 'hx_myocardial_infarction',
    'hx_stroke', 'hx_copd', 'cardiogenic_shock', 'invasive_vent'
]
categorical_vars = [v for v in categorical_vars if v in df_table1.columns]

nonnormal_vars = [
    'ban_adhf_total_score', 'creatinine', 'bun', 'ntprobnp',
    'total_furosemide_equivalent_mg', 'cci_score'
]
nonnormal_vars = [v for v in nonnormal_vars if v in df_table1.columns]

labels = {
    'age': 'Age, years',
    'male_sex': 'Male sex',
    'ban_adhf_total_score': 'BAN-ADHF score',
    'creatinine': 'Creatinine, mg/dL',
    'bun': 'BUN, mg/dL',
    'ntprobnp': 'NT-proBNP, pg/mL',
    'dbp': 'Diastolic BP, mmHg',
    'total_furosemide_equivalent_mg': 'Home diuretic dose, mg/day +',
    'hx_atrial_fibrillation': 'Atrial fibrillation',
    'hx_hypertension': 'Hypertension',
    'prior_hf_hospitalization_12mo': 'Prior HF hospitalization (12 mo)',
    'lvef': 'LVEF, % ++',
    'hf_phenotype': 'HF phenotype ++',
    'hx_diabetes': 'Diabetes mellitus',
    'hx_renal_disease': 'Chronic kidney disease',
    'hx_myocardial_infarction': 'Prior myocardial infarction',
    'hx_stroke': 'Prior stroke',
    'hx_copd': 'COPD',
    'cci_score': 'Charlson Comorbidity Index',
    'cardiogenic_shock': 'Cardiogenic shock',
    'invasive_vent': 'Invasive mechanical ventilation'
}

order = {
    'risk_category': ['Low', 'Moderate', 'High'],
    'male_sex': ['Yes', 'No'],
    'hx_atrial_fibrillation': ['Yes', 'No'],
    'hx_hypertension': ['Yes', 'No'],
    'prior_hf_hospitalization_12mo': ['Yes', 'No'],
    'hx_diabetes': ['Yes', 'No'],
    'hx_renal_disease': ['Yes', 'No'],
    'hx_myocardial_infarction': ['Yes', 'No'],
    'hx_stroke': ['Yes', 'No'],
    'hx_copd': ['Yes', 'No'],
    'cardiogenic_shock': ['Yes', 'No'],
    'invasive_vent': ['Yes', 'No']
}

limit = {k: 1 for k in [
    'male_sex', 'hx_atrial_fibrillation', 'hx_hypertension',
    'prior_hf_hospitalization_12mo', 'hx_diabetes', 'hx_renal_disease',
    'hx_myocardial_infarction', 'hx_stroke', 'hx_copd',
    'cardiogenic_shock', 'invasive_vent'
] if k in df_table1.columns}

table1 = TableOne(
    df_table1,
    columns=all_columns,
    categorical=categorical_vars,
    nonnormal=nonnormal_vars,
    groupby='risk_category',
    pval=True,
    rename=labels,
    missing=False,
    overall=True,
    order=order,
    limit=limit,
    decimals=1
)

print("\n")
print(table1.tabulate(tablefmt="simple"))

print("\n" + "-"*70)
print("+ Home diuretic dose as oral furosemide equivalents")
if 'lvef' in df_table1.columns:
    print("++ LVEF available in", int(df_table1['lvef'].notna().sum()), "patients")

table1.to_csv('/content/Table1_baseline_by_risk.csv')
table1.to_excel('/content/Table1_baseline_by_risk.xlsx')
print("\n✓ Table 1 saved")

# Verification gradients (works only if original columns are numeric 0/1 in df)
print("\n" + "-"*70)
print("VERIFICATION: Comorbidity Gradients (should increase with higher risk)")
print("-"*70)

df_verify = df.copy()

def safe_mean(subset, col):
    """Mean of a 0/1 column coerced to numeric; NaN if the column is absent."""
    if col not in subset.columns:
        return np.nan
    return pd.to_numeric(subset[col], errors='coerce').mean()

for cat in ['Low', 'Moderate', 'High']:
    subset = df_verify[df_verify['risk_category'] == cat]
    n_cat = len(subset)
    dm_pct = 100 * safe_mean(subset, 'hx_diabetes')
    ckd_pct = 100 * safe_mean(subset, 'hx_renal_disease')
    cs_pct = 100 * safe_mean(subset, 'cardiogenic_shock')
    hf_hosp = 100 * safe_mean(subset, 'prior_hf_hospitalization_12mo')
    print(f"  {cat}: N={n_cat}, DM={dm_pct:.1f}%, CKD={ckd_pct:.1f}%, CS={cs_pct:.1f}%, Prior HF hosp={hf_hosp:.1f}%")
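# Sketch (not part of the original analysis): the printout above checks the
# gradients by eye. A Cochran-Armitage test for trend is one way to test whether
# a binary comorbidity rises across the ordered Low/Moderate/High categories.
# `_cochran_armitage` is a hypothetical helper, the scores (0, 1, 2) assume equal
# spacing between categories, and the counts below are invented for illustration.
import numpy as np
from scipy.stats import norm

def _cochran_armitage(events, totals, scores=(0, 1, 2)):
    r = np.asarray(events, dtype=float)   # patients with the condition per category
    n = np.asarray(totals, dtype=float)   # patients per category
    s = np.asarray(scores, dtype=float)
    N = n.sum()
    p_bar = r.sum() / N
    t_stat = np.sum(s * (r - n * p_bar))
    var = p_bar * (1 - p_bar) * (np.sum(n * s**2) - np.sum(n * s) ** 2 / N)
    z = t_stat / np.sqrt(var)
    return float(z), float(2 * norm.sf(abs(z)))  # z and two-sided p-value

_demo_z_up, _demo_p_up = _cochran_armitage([20, 45, 70], [200, 200, 200])
_demo_z_flat, _demo_p_flat = _cochran_armitage([40, 40, 40], [200, 200, 200])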

#==========================================================================
# CALCULATE EFFICIENCY BY RISK FOR TABLE 2
#==========================================================================

print("\n" + "-"*70)
print("Calculating efficiency distributions by risk category...")
print("-"*70)

df_24h = df[(df['icu_stay_ge_24h'] == 1) &
            (df['diuretic_efficiency_24h'].notna()) &
            (df['diuretic_efficiency_24h'] > 0)].copy()

eff_24h_by_risk = {}
for cat in ['Low', 'Moderate', 'High']:
    subset = df_24h[df_24h['risk_category'] == cat]['diuretic_efficiency_24h']
    eff_24h_by_risk[cat] = {'n': len(subset), 'median': subset.median(),
                           'q1': subset.quantile(0.25), 'q3': subset.quantile(0.75)}
    print(f"  24h {cat}: N={eff_24h_by_risk[cat]['n']}, Median={eff_24h_by_risk[cat]['median']:.1f} [{eff_24h_by_risk[cat]['q1']:.1f}-{eff_24h_by_risk[cat]['q3']:.1f}]")

df_72h = df[(df['icu_stay_ge_72h'] == 1) &
            (df['diuretic_efficiency_72h'].notna()) &
            (df['diuretic_efficiency_72h'] > 0)].copy()

eff_72h_by_risk = {}
for cat in ['Low', 'Moderate', 'High']:
    subset = df_72h[df_72h['risk_category'] == cat]['diuretic_efficiency_72h']
    eff_72h_by_risk[cat] = {'n': len(subset), 'median': subset.median(),
                           'q1': subset.quantile(0.25), 'q3': subset.quantile(0.75)}
    print(f"  72h {cat}: N={eff_72h_by_risk[cat]['n']}, Median={eff_72h_by_risk[cat]['median']:.1f} [{eff_72h_by_risk[cat]['q1']:.1f}-{eff_72h_by_risk[cat]['q3']:.1f}]")
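# Sketch: the "4.2-fold difference" in the study overview is the ratio of the
# Low-risk to High-risk median 24h efficiency (47.4 / 11.3 mL/mg), and the
# across-category p-value can come from a Kruskal-Wallis test on the raw values.
# `_fold_and_kruskal` is a hypothetical helper and the toy groups are invented;
# with the real data one would pass the three per-category Series from df_24h.
import numpy as np
from scipy.stats import kruskal

_demo_fold = 47.4 / 11.3  # about 4.19, reported as 4.2-fold

def _fold_and_kruskal(low, moderate, high):
    """Return (median(low) / median(high), Kruskal-Wallis p across the groups)."""
    fold = float(np.median(low) / np.median(high))
    p = float(kruskal(low, moderate, high).pvalue)
    return fold, p

_demo_fold_toy, _demo_p_toy = _fold_and_kruskal(
    [40, 45, 50, 55, 60], [25, 28, 30, 33, 35], [8, 10, 12, 14, 16])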


#==========================================================================
# TABLE 2: PRIMARY OUTCOMES AND DISCRIMINATION SUMMARY
#==========================================================================

print("\n" + "="*70)
print("TABLE 2: Primary Outcomes and Discrimination Summary")
print("="*70)

# Recompute the primary discrimination metrics here instead of reusing
# results_24h, which a later scenario analysis overwrites (avoids a KeyError).
results_24h_clean = compute_primary_eff_metrics(
    df_24h,
    score_col='ban_adhf_total_score',
    eff_col='diuretic_efficiency_24h',
    q_list=(0.20, 0.25),
    n_boot=n_bootstrap
)

results_72h_clean = compute_primary_eff_metrics(
    df_72h,
    score_col='ban_adhf_total_score',
    eff_col='diuretic_efficiency_72h',
    q_list=(0.20,),
    n_boot=n_bootstrap
)

# Pull AUROCs from the clean dict
auroc24_q20 = results_24h_clean["aurocs"][0.20]["auroc"]
auroc24_q20_ci = results_24h_clean["aurocs"][0.20]["auroc_ci"]
auroc24_q25 = results_24h_clean["aurocs"][0.25]["auroc"]
auroc24_q25_ci = results_24h_clean["aurocs"][0.25]["auroc_ci"]

auroc72_q20 = results_72h_clean["aurocs"][0.20]["auroc"]
auroc72_q20_ci = results_72h_clean["aurocs"][0.20]["auroc_ci"]

from scipy.stats import kruskal

def fmt_p(p):
    """Format a p-value as '<0.001' or to three decimals."""
    return "<0.001" if p < 0.001 else f"{p:.3f}"

# Kruskal-Wallis tests of efficiency across risk categories (footnote *),
# computed rather than hard-coded
risk_levels = ['Low', 'Moderate', 'High']
kw_p_24h = kruskal(*[df_24h.loc[df_24h['risk_category'] == c, 'diuretic_efficiency_24h']
                     for c in risk_levels]).pvalue
kw_p_72h = kruskal(*[df_72h.loc[df_72h['risk_category'] == c, 'diuretic_efficiency_72h']
                     for c in risk_levels]).pvalue

table2_rows = []

# 24-Hour Diuretic Efficiency
table2_rows.append(['24-Hour Diuretic Efficiency', 'N', f"{results_24h_clean['n']:,}", 'NA', 'NA'])
table2_rows.append(['', 'Spearman rho', f"{results_24h_clean['spearman_rho']:.3f}",
                    format_ci_tuple(results_24h_clean['spearman_ci']),
                    fmt_p(results_24h_clean['spearman_p'])])
table2_rows.append(['', 'Pearson r', f"{results_24h_clean['pearson_r']:.3f}",
                    format_ci_tuple(results_24h_clean['pearson_ci']),
                    fmt_p(results_24h_clean['pearson_p'])])
table2_rows.append(['', 'C-index', f"{results_24h_clean['c_index']:.3f}",
                    format_ci_tuple(results_24h_clean['c_index_ci']), 'NA'])
table2_rows.append(['', 'AUROC (Q20)', f"{auroc24_q20:.3f}",
                    format_ci_tuple(auroc24_q20_ci), 'NA'])
table2_rows.append(['', 'AUROC (Q25)', f"{auroc24_q25:.3f}",
                    format_ci_tuple(auroc24_q25_ci), 'NA'])

for cat in risk_levels:
    e = eff_24h_by_risk[cat]
    p_val = f"{fmt_p(kw_p_24h)} *" if cat == 'High' else ''
    table2_rows.append(['', f'  {cat}, mL/mg', f"{e['median']:.1f}",
                        f"[{e['q1']:.1f}-{e['q3']:.1f}]", p_val])

# 72-Hour Diuretic Efficiency
table2_rows.append(['72-Hour Diuretic Efficiency', 'N', f"{results_72h_clean['n']:,}", 'NA', 'NA'])
table2_rows.append(['', 'Spearman rho', f"{results_72h_clean['spearman_rho']:.3f}",
                    format_ci_tuple(results_72h_clean['spearman_ci']),
                    fmt_p(results_72h_clean['spearman_p'])])
table2_rows.append(['', 'AUROC (Q20)', f"{auroc72_q20:.3f}",
                    format_ci_tuple(auroc72_q20_ci), 'NA'])

for cat in risk_levels:
    e = eff_72h_by_risk[cat]
    p_val = f"{fmt_p(kw_p_72h)} *" if cat == 'High' else ''
    table2_rows.append(['', f'  {cat}, mL/mg', f"{e['median']:.1f}",
                        f"[{e['q1']:.1f}-{e['q3']:.1f}]", p_val])

# Diuretic resistance (from results_dr, computed in an earlier cell)
table2_rows.append(['Diuretic Resistance +', 'N', f"{results_dr['n']:,}", 'NA', 'NA'])
table2_rows.append(['', 'Prevalence', f"{results_dr['prevalence']:.1f}%", 'NA', 'NA'])
table2_rows.append(['', 'AUROC', f"{results_dr['auroc']:.3f}",
                    format_ci_tuple(results_dr['auroc_ci']), 'NA'])

for cat in ['Low', 'Moderate', 'High']:
    rate = results_dr['dr_by_risk'][cat]['rate']
    p_val = '<0.001 *' if cat == 'High' else ''
    table2_rows.append(['', f'  {cat}', f"{rate:.1f}%", 'NA', p_val])

# In-hospital mortality (from results_mortality, computed in an earlier cell)
table2_rows.append(['In-Hospital Mortality ++', 'N', f"{results_mortality['n']:,}", 'NA', 'NA'])
table2_rows.append(['', 'Mortality rate', f"{results_mortality['mortality_rate']:.1f}%", 'NA', 'NA'])
table2_rows.append(['', 'AUROC', f"{results_mortality['auroc']:.3f}",
                    format_ci_tuple(results_mortality['auroc_ci']), 'NA'])
table2_rows.append(['', 'AUROC (CS subgroup)', f"{results_mortality['auroc_cs']:.3f}",
                    format_ci_tuple(results_mortality['auroc_cs_ci']), 'NA'])

for cat in ['Low', 'Moderate', 'High']:
    rate = results_mortality['mort_by_risk'][cat]['rate']
    p_val = '0.001 *' if cat == 'High' else ''
    table2_rows.append(['', f'  {cat}', f"{rate:.1f}%", 'NA', p_val])

table2_df = pd.DataFrame(table2_rows, columns=['Outcome', 'Metric', 'Value', '95% CI', 'p-value'])

print("\n")
print(table2_df.to_string(index=False))
print("\n" + "-"*70)
print("* Kruskal-Wallis or Chi-square test across risk categories")
print("+ Urine output <=3000 mL in first 24 hours")
print("++ Exploratory outcome; BAN-ADHF was designed to predict diuretic efficiency, not mortality")

table2_df.to_csv('/content/Table2_outcomes_discrimination.csv', index=False)
table2_df.to_excel('/content/Table2_outcomes_discrimination.xlsx', index=False)
print("\n✓ Table 2 saved")


#==========================================================================
# TABLE 3: LITERATURE COMPARISON
#==========================================================================

print("\n" + "="*70)
print("TABLE 3: Literature Comparison")
print("="*70)

# Reference values from prior BAN-ADHF validation studies (Table 3 rows)
ref_r = 0.40
ref_q20 = 0.84
ref_q25 = 0.70
ref_dr = 0.631

delta_pearson = ((abs(results_24h_clean['pearson_r']) - ref_r) / ref_r) * 100
delta_q20 = ((auroc24_q20 - ref_q20) / ref_q20) * 100
delta_q25 = ((auroc24_q25 - ref_q25) / ref_q25) * 100
delta_dr = ((results_dr['auroc'] - ref_dr) / ref_dr) * 100
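# Worked check of the relative-difference formula used above,
#   delta = (this study - reference) / reference * 100,
# with the 24h AUROC for the lowest-efficiency quintile (0.780, study overview)
# against the Segar external-validation reference (0.84). For Pearson r the code
# compares magnitudes (|r|), so a stronger inverse correlation gives a positive
# delta. `_rel_diff_pct` is a hypothetical helper for illustration.
def _rel_diff_pct(value, reference):
    return (value - reference) / reference * 100

_demo_delta_q20 = _rel_diff_pct(0.780, 0.84)  # about -7.1%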

table3_rows = [
    ['Segar 2024', 'DOSE/ESCAPE', '707', 'NA', 'NA', f"{ref_q20:.2f}", 'NA', 'NA', 'External validation'],
    ['Pandey 2025', 'CLOROTIC', '220', 'NA', 'B=-0.18/pt +', 'NA', f"{ref_q25:.2f}", 'NA', 'Trial validation'],
    ['Mauch 2025', 'Floor patients', '317', 'NA', 'r=-0.40', 'NA', 'NA', f"{ref_dr:.3f}", 'Real-world'],
    ['This Study', 'ICU ADHF', f"{results_24h_clean['n']:,}",
     f"{results_24h_clean['spearman_rho']:.3f}",
     f"r={results_24h_clean['pearson_r']:.3f}",
     f"{auroc24_q20:.3f}",
     f"{auroc24_q25:.3f}",
     f"{results_dr['auroc']:.3f}",
     'ICU validation'],
]

table3_df = pd.DataFrame(
    table3_rows,
    columns=['Study', 'Population', 'N', 'Spearman rho', 'Association',
             'AUROC Q20', 'AUROC Q25', 'Binary DR', 'Notes']
)

print("\n")
print(table3_df.to_string(index=False))

print("\n" + "-"*70)
print("COMPARISON VS REFERENCES:")
print("-"*70)
print(f"  Pearson |r|:  This study |r|={abs(results_24h_clean['pearson_r']):.3f} vs Mauch |r|={ref_r:.2f} -> {delta_pearson:+.1f}%")
print(f"  AUROC Q20:    This study {auroc24_q20:.3f} vs Segar ext {ref_q20:.2f} -> {delta_q20:+.1f}%")
print(f"  AUROC Q25:    This study {auroc24_q25:.3f} vs Pandey {ref_q25:.2f} -> {delta_q25:+.1f}%")
print(f"  Binary DR:    This study {results_dr['auroc']:.3f} vs Mauch {ref_dr:.3f} -> {delta_dr:+.1f}%")

print("\n" + "-"*70)
print("NOTES:")
print("-"*70)
print("+ Pandey: beta coefficient (regression slope), not correlation")
print("- This study Binary DR: UOP <=3000 mL, prevalence {:.1f}%".format(results_dr['prevalence']))
print("- Lower Binary DR AUROC can reflect higher prevalence and dichotomization")
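# Sketch (simulation, not part of the analysis): the note above says a lower
# binary-DR AUROC can reflect prevalence and dichotomization. Here one continuous
# outcome is dichotomized at different quantiles and the AUROC of the same score
# is recomputed. The correlation (-0.5), sample size and quantiles are assumptions
# chosen for illustration; a local RNG leaves the global seed untouched.
import numpy as np
from sklearn.metrics import roc_auc_score

_demo_rng3 = np.random.default_rng(11)
_demo_score3 = _demo_rng3.normal(size=20000)
_demo_eff3 = -0.5 * _demo_score3 + np.sqrt(1 - 0.25) * _demo_rng3.normal(size=20000)

_demo_auc_by_q = {}
for _q in (0.05, 0.20, 0.50):
    _low = (_demo_eff3 <= np.quantile(_demo_eff3, _q)).astype(int)  # event = low efficiency
    _demo_auc_by_q[_q] = roc_auc_score(_low, _demo_score3)
# The AUROC changes with the cut point even though the score-outcome
# relationship is identical in every case.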

table3_df.to_csv('/content/Table3_literature_comparison.csv', index=False)
table3_df.to_excel('/content/Table3_literature_comparison.xlsx', index=False)
print("\n✓ Table 3 saved")


#==========================================================================
# KEY FINDINGS SUMMARY FOR MANUSCRIPT
#==========================================================================

print("\n" + "="*70)
print("KEY FINDINGS SUMMARY FOR MANUSCRIPT")
print("="*70)

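# Format p-values from the computed results rather than hard-coding them
p_rho_24h_str = ("p<0.001" if results_24h_clean['spearman_p'] < 0.001
                 else f"p={results_24h_clean['spearman_p']:.3f}")
p_r_24h_str = ("p<0.001" if results_24h_clean['pearson_p'] < 0.001
               else f"p={results_24h_clean['pearson_p']:.3f}")

print(f"""
PRIMARY OUTCOME - 24h Diuretic Efficiency:
  Spearman rho = {results_24h_clean['spearman_rho']:.3f} (95% CI: {results_24h_clean['spearman_ci'][0]:.3f}, {results_24h_clean['spearman_ci'][1]:.3f}), {p_rho_24h_str}
  Pearson r = {results_24h_clean['pearson_r']:.3f} (95% CI: {results_24h_clean['pearson_ci'][0]:.3f}, {results_24h_clean['pearson_ci'][1]:.3f}), {p_r_24h_str}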
  AUROC (Q20) = {auroc24_q20:.3f} (95% CI: {auroc24_q20_ci[0]:.3f}, {auroc24_q20_ci[1]:.3f})

COMPARISON TO LITERATURE:
  Pearson r:     {results_24h_clean['pearson_r']:.3f} vs -0.40 (Mauch)
  AUROC Q20:     {auroc24_q20:.3f} vs 0.84 (Segar validation)
  AUROC Q25:     {auroc24_q25:.3f} vs 0.70 (Pandey)
  Binary DR:     {results_dr['auroc']:.3f} vs 0.631 (Mauch)

RISK STRATIFICATION (from existing results dicts):
  Low:      Efficiency {eff_24h_by_risk['Low']['median']:.1f} mL/mg, DR {results_dr['dr_by_risk']['Low']['rate']:.1f}%, Mortality {results_mortality['mort_by_risk']['Low']['rate']:.1f}%
  Moderate: Efficiency {eff_24h_by_risk['Moderate']['median']:.1f} mL/mg, DR {results_dr['dr_by_risk']['Moderate']['rate']:.1f}%, Mortality {results_mortality['mort_by_risk']['Moderate']['rate']:.1f}%
  High:     Efficiency {eff_24h_by_risk['High']['median']:.1f} mL/mg, DR {results_dr['dr_by_risk']['High']['rate']:.1f}%, Mortality {results_mortality['mort_by_risk']['High']['rate']:.1f}%
""")


#==========================================================================
# R DATA EXPORTS
#==========================================================================

print("\n" + "="*70)
print("DATA EXPORTS FOR R")
print("="*70)

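# Forest plot data (exported only if the subgroup analysis cell has been run;
# globals().get avoids a NameError when results_subgroups was never defined)
results_subgroups_obj = globals().get('results_subgroups')
if isinstance(results_subgroups_obj, dict) and ('forest_data' in results_subgroups_obj):
    results_subgroups_obj['forest_data'].to_csv('/content/data_for_R_forest_plot.csv', index=False)
    print("✓ data_for_R_forest_plot.csv")
else:
    print("Forest plot export skipped: results_subgroups['forest_data'] not found.")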

# Outcomes by risk
outcomes_risk = []
for cat in ['Low', 'Moderate', 'High']:
    outcomes_risk.append({
        'risk_category': cat,
        'eff_24h_median': eff_24h_by_risk[cat]['median'],
        'eff_24h_q1': eff_24h_by_risk[cat]['q1'],
        'eff_24h_q3': eff_24h_by_risk[cat]['q3'],
        'eff_72h_median': eff_72h_by_risk[cat]['median'],
        'eff_72h_q1': eff_72h_by_risk[cat]['q1'],
        'eff_72h_q3': eff_72h_by_risk[cat]['q3'],
        'dr_rate': results_dr['dr_by_risk'][cat]['rate'],
        'mortality_rate': results_mortality['mort_by_risk'][cat]['rate']
    })

pd.DataFrame(outcomes_risk).to_csv('/content/data_for_R_outcomes_by_risk.csv', index=False)
print("✓ data_for_R_outcomes_by_risk.csv")

# Raw efficiency data for box plots
df_eff_r = df[['hadm_id', 'ban_adhf_total_score', 'risk_category',
               'diuretic_efficiency_24h', 'diuretic_efficiency_72h']].copy()
df_eff_r.to_csv('/content/data_for_R_efficiency_raw.csv', index=False)
print("✓ data_for_R_efficiency_raw.csv")

# Literature comparison data for plotting.
# Pandey 2025 reported a regression slope (beta per score point), not a
# correlation, so it is kept in its own column rather than under pearson_r.
lit_compare = pd.DataFrame({
    'study': ['Segar 2024 (validation)', 'Pandey 2025', 'Mauch 2025', 'This Study'],
    'population': ['DOSE/ESCAPE', 'CLOROTIC', 'Floor', 'ICU'],
    'n': [707, 220, 317, results_24h_clean['n']],
    'pearson_r': [np.nan, np.nan, -0.40, results_24h_clean['pearson_r']],
    'beta_per_point': [np.nan, -0.18, np.nan, np.nan],
    'auroc_q20': [0.84, np.nan, np.nan, auroc24_q20],
    'auroc_q25': [np.nan, 0.70, np.nan, auroc24_q25],
    'auroc_dr': [np.nan, np.nan, 0.631, results_dr['auroc']]
})
lit_compare.to_csv('/content/data_for_R_literature_comparison.csv', index=False)
print("✓ data_for_R_literature_comparison.csv")

print("\n" + "="*70)
print("✓ ALL MAIN TABLES COMPLETE")
print("="*70)
print("""
Files generated:
  - Table1_baseline_by_risk.csv/.xlsx
  - Table2_outcomes_discrimination.csv/.xlsx
  - Table3_literature_comparison.csv/.xlsx
  - data_for_R_forest_plot.csv (only if subgroup results were available)
  - data_for_R_outcomes_by_risk.csv
  - data_for_R_efficiency_raw.csv
  - data_for_R_literature_comparison.csv
""")
print("\n-> Next: Supplementary Tables (Cell 17)")

#==========================================================================
# SECTION 13: SUPPLEMENTARY TABLES
# Cell 17: Generate Supplementary Tables (eTables 1-4)
#==========================================================================

from tableone import TableOne
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

print("="*70)
print("SUPPLEMENTARY TABLES")
print("="*70)

# -------------------------------------------------------------------------
# Helper: find the correct primary 24h discrimination dict
# -------------------------------------------------------------------------

def find_primary_24h_results():
    """
    Locate the primary 24h discrimination results dict in the notebook globals.

    The target dict has at least the keys n, spearman_rho, spearman_ci,
    auroc_quintile and auroc_quintile_ci (typically also pearson_r, pearson_ci,
    c_index, c_index_ci, auroc_quartile, auroc_quartile_ci). `results_24h` is not
    used because it holds the scenario (A/B/C) results and lacks these keys.
    The search matches on keys rather than on a hard-coded variable name.
    """
    required = {"n", "spearman_rho", "spearman_ci", "auroc_quintile", "auroc_quintile_ci"}
    candidates = [
        (name, obj) for name, obj in globals().items()
        if isinstance(obj, dict) and required.issubset(obj.keys())
    ]

    if len(candidates) == 0:
        # List dicts with related keys to help debugging
        maybe = []
        for name, obj in globals().items():
            keys_str = str(obj.keys()).lower() if isinstance(obj, dict) else ""
            if any(tag in keys_str for tag in ("rho", "auroc", "c_index")):
                maybe.append((name, list(obj.keys())[:25]))
        raise KeyError(
            "Could not find the primary 24h discrimination results dict.\n"
            "Expected a dict with keys like: n, spearman_rho, spearman_ci, auroc_quintile, auroc_quintile_ci.\n"
            f"Possible related dicts found (name -> sample keys): {maybe}"
        )

    # If several dicts match, prefer a conventionally named one
    preferred_names = {"results_primary_24h", "results_24h_primary",
                       "results_discrimination_24h", "results_24h_discrimination"}
    for nm, ob in candidates:
        if nm.lower() in preferred_names:
            return nm, ob
    # Otherwise return the first match (globals() preserves definition order)
    return candidates[0]


primary_24h_name, primary_24h = find_primary_24h_results()
print(f"\nUsing primary 24h discrimination results from: {primary_24h_name}")

# -------------------------------------------------------------------------
# eTable 1: BASELINE CHARACTERISTICS BY MORTALITY STATUS
# -------------------------------------------------------------------------

print("\n" + "="*70)
print("eTable 1: Baseline Characteristics by Mortality Status")
print("="*70)

df_etable1 = df.copy()
df_etable1['mortality_status'] = df_etable1['hospital_expire_flag'].map({0: 'Survivors', 1: 'Non-survivors'})

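# male_sex coded as in Table 1 (robust to case, whitespace and a missing column)
if 'gender' not in df_etable1.columns:
    df_etable1['male_sex'] = np.nan
elif df_etable1['gender'].dtype == 'object':
    g = df_etable1['gender'].astype(str).str.strip().str.upper()
    df_etable1['male_sex'] = g.isin(['M', 'MALE', 'MAN']).astype(int)
else:
    df_etable1['male_sex'] = pd.to_numeric(df_etable1['gender'], errors='coerce')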

df_etable1['risk_category'] = df_etable1['risk_category'].astype(str)

binary_vars_to_fix = [
    'male_sex', 'hx_atrial_fibrillation', 'hx_hypertension',
    'prior_hf_hospitalization_12mo', 'hx_diabetes', 'hx_renal_disease',
    'hx_myocardial_infarction', 'hx_stroke', 'hx_copd',
    'cardiogenic_shock', 'invasive_vent'
]

for var in binary_vars_to_fix:
    if var in df_etable1.columns:
        df_etable1[var] = df_etable1[var].map({1: 'Yes', 0: 'No', True: 'Yes', False: 'No'})

all_columns = [
    'age', 'male_sex', 'ban_adhf_total_score', 'risk_category',
    'creatinine', 'bun', 'ntprobnp', 'dbp',
    'total_furosemide_equivalent_mg',
    'hx_atrial_fibrillation', 'hx_hypertension', 'prior_hf_hospitalization_12mo',
    'lvef', 'hf_phenotype',
    'hx_diabetes', 'hx_renal_disease', 'hx_myocardial_infarction',
    'hx_stroke', 'hx_copd', 'cci_score',
    'cardiogenic_shock', 'invasive_vent'
]
all_columns = [v for v in all_columns if v in df_etable1.columns]

categorical_vars = [
    'male_sex', 'risk_category', 'hx_atrial_fibrillation', 'hx_hypertension',
    'prior_hf_hospitalization_12mo', 'hf_phenotype',
    'hx_diabetes', 'hx_renal_disease', 'hx_myocardial_infarction',
    'hx_stroke', 'hx_copd', 'cardiogenic_shock', 'invasive_vent'
]
categorical_vars = [v for v in categorical_vars if v in df_etable1.columns]

nonnormal_vars = ['ban_adhf_total_score', 'creatinine', 'bun', 'ntprobnp',
                  'total_furosemide_equivalent_mg', 'cci_score']
nonnormal_vars = [v for v in nonnormal_vars if v in df_etable1.columns]

labels = {
    'age': 'Age, years',
    'male_sex': 'Male sex',
    'ban_adhf_total_score': 'BAN-ADHF score',
    'risk_category': 'Risk category',
    'creatinine': 'Creatinine, mg/dL',
    'bun': 'BUN, mg/dL',
    'ntprobnp': 'NT-proBNP, pg/mL',
    'dbp': 'Diastolic BP, mmHg',
    'total_furosemide_equivalent_mg': 'Home diuretic dose, mg/day',
    'hx_atrial_fibrillation': 'Atrial fibrillation',
    'hx_hypertension': 'Hypertension',
    'prior_hf_hospitalization_12mo': 'Prior HF hospitalization (12 mo)',
    'lvef': 'LVEF, %',
    'hf_phenotype': 'HF phenotype',
    'hx_diabetes': 'Diabetes mellitus',
    'hx_renal_disease': 'Chronic kidney disease',
    'hx_myocardial_infarction': 'Prior MI',
    'hx_stroke': 'Prior stroke',
    'hx_copd': 'COPD',
    'cci_score': 'Charlson Comorbidity Index',
    'cardiogenic_shock': 'Cardiogenic shock',
    'invasive_vent': 'Invasive mechanical ventilation'
}

order = {
    'mortality_status': ['Survivors', 'Non-survivors'],
    'risk_category': ['Low', 'Moderate', 'High'],
    'male_sex': ['Yes', 'No'],
    'hx_atrial_fibrillation': ['Yes', 'No'],
    'hx_hypertension': ['Yes', 'No'],
    'prior_hf_hospitalization_12mo': ['Yes', 'No'],
    'hx_diabetes': ['Yes', 'No'],
    'hx_renal_disease': ['Yes', 'No'],
    'hx_myocardial_infarction': ['Yes', 'No'],
    'hx_stroke': ['Yes', 'No'],
    'hx_copd': ['Yes', 'No'],
    'cardiogenic_shock': ['Yes', 'No'],
    'invasive_vent': ['Yes', 'No']
}

limit = {k: 1 for k in [
    'male_sex', 'hx_atrial_fibrillation', 'hx_hypertension',
    'prior_hf_hospitalization_12mo', 'hx_diabetes', 'hx_renal_disease',
    'hx_myocardial_infarction', 'hx_stroke', 'hx_copd',
    'cardiogenic_shock', 'invasive_vent'
] if k in df_etable1.columns}

etable1 = TableOne(
    df_etable1,
    columns=all_columns,
    categorical=categorical_vars,
    nonnormal=nonnormal_vars,
    groupby='mortality_status',
    pval=True,
    rename=labels,
    missing=False,
    overall=True,
    order=order,
    limit=limit,
    decimals=1
)

print("\n")
print(etable1.tabulate(tablefmt="simple"))

etable1.to_csv('/content/eTable1_baseline_by_mortality.csv')
etable1.to_excel('/content/eTable1_baseline_by_mortality.xlsx')
print("\n✓ eTable 1 saved")


# -------------------------------------------------------------------------
# eTable 2: SENSITIVITY ANALYSES SUMMARY
# -------------------------------------------------------------------------

print("\n" + "="*70)
print("eTable 2: Sensitivity Analyses Summary")
print("="*70)

etable2_rows = []

# Primary analysis (from the primary_24h dict located by find_primary_24h_results)
etable2_rows.append([
    'Primary analysis',
    primary_24h['n'],
    f"{primary_24h['spearman_rho']:.3f}",
    f"({primary_24h['spearman_ci'][0]:.3f}, {primary_24h['spearman_ci'][1]:.3f})",
    f"{primary_24h['auroc_quintile']:.3f}",
    f"({primary_24h['auroc_quintile_ci'][0]:.3f}, {primary_24h['auroc_quintile_ci'][1]:.3f})",
    '(ref)',
    '(ref)'
])

# Sensitivity analyses
sa_list = [
    ('Exclude advanced CKD +', results_sensitivity['exclude_ckd']),
    ('Exclude outliers (>P99)', results_sensitivity['exclude_outliers']),
    ('Exclude cardiogenic shock', results_sensitivity['exclude_cs']),
    ('Complete LVEF data only', results_sensitivity['complete_lvef'])
]

for name, sa in sa_list:
    delta_rho = sa['spearman'] - primary_24h['spearman_rho']
    delta_auroc = sa['auroc'] - primary_24h['auroc_quintile']
    etable2_rows.append([
        name,
        sa['n'],
        f"{sa['spearman']:.3f}",
        f"({sa['spearman_ci'][0]:.3f}, {sa['spearman_ci'][1]:.3f})",
        f"{sa['auroc']:.3f}",
        f"({sa['auroc_ci'][0]:.3f}, {sa['auroc_ci'][1]:.3f})",
        f"{delta_rho:+.3f}",
        f"{delta_auroc:+.3f}"
    ])

etable2_df = pd.DataFrame(
    etable2_rows,
    columns=['Analysis', 'N', 'Spearman rho', 'rho 95% CI',
             'AUROC (Q20)', 'AUROC 95% CI', 'Delta_rho', 'Delta_AUROC']
)

print("\n")
print(etable2_df.to_string(index=False))
print("\n" + "-"*70)
print("+ Attenuation after excluding advanced CKD is expected, reflecting confounding by renal function (see Discussion)")

# Threshold sensitivity (efficiency)
if 'threshold_results' in results_sensitivity:
    print("\n" + "-"*70)
    print("Alternative Efficiency Thresholds:")
    print("-"*70)
    print(f"{'Threshold':<20} {'N Events':<12} {'AUROC':<10} {'95% CI':<25}")
    print("-"*70)

    for t in results_sensitivity['threshold_results']:
        ci_str = f"({t['auroc_ci'][0]:.3f}, {t['auroc_ci'][1]:.3f})"
        print(f"{t['label']:<20} {t['n_events']:<12} {t['auroc']:.3f}      {ci_str}")
else:
    print("\nNote: results_sensitivity['threshold_results'] not found. Skipping threshold printout.")

# DR threshold sensitivity
if 'dr_results' in results_sensitivity:
    print("\n" + "-"*70)
    print("Alternative Diuretic Resistance Thresholds:")
    print("-"*70)
    print(f"{'Threshold':<15} {'N Resistant':<15} {'Prevalence':<12} {'AUROC':<10} {'95% CI':<25}")
    print("-"*70)

    for t in results_sensitivity['dr_results']:
        marker = " <- Primary" if t['threshold'] == 3000 else ""
        ci_str = f"({t['auroc_ci'][0]:.3f}, {t['auroc_ci'][1]:.3f})"
        prev_str = f"{t['prevalence']:.1f}%"
        print(f"<={t['threshold']} mL      {t['n_dr']:<15} {prev_str:<12} {t['auroc']:.3f}      {ci_str}{marker}")
else:
    print("\nNote: results_sensitivity['dr_results'] not found. Skipping DR threshold printout.")

etable2_df.to_csv('/content/eTable2_sensitivity_analyses.csv', index=False)
etable2_df.to_excel('/content/eTable2_sensitivity_analyses.xlsx', index=False)
print("\n✓ eTable 2 saved")


# -------------------------------------------------------------------------
# eTable 3: SECONDARY OUTCOMES
# -------------------------------------------------------------------------

print("\n" + "="*70)
print("eTable 3: Secondary Outcomes")
print("="*70)

etable3_rows = []

binary_outcomes = [
    ('Vasopressor use', results_secondary['vasopressor']),
    ('Inotrope use', results_secondary['inotrope']),
    ('MCS use', results_secondary['mcs']),
    ('Invasive ventilation', results_secondary['ventilation'])
]

for name, res in binary_outcomes:
    auroc_str = f"{res['auroc']:.3f}" if res.get('auroc', None) is not None else "NA"
    auroc_ci = f"({res['auroc_ci'][0]:.3f}, {res['auroc_ci'][1]:.3f})" if res.get('auroc_ci', None) is not None else "NA"
    p_str = "<0.001" if res['p_value'] < 0.001 else f"{res['p_value']:.3f}"

    etable3_rows.append([
        name,
        res['n'],
        f"{res['prevalence']:.1f}%",
        auroc_str,
        auroc_ci,
        f"{res['rates_by_risk']['Low']['rate']:.1f}%",
        f"{res['rates_by_risk']['Moderate']['rate']:.1f}%",
        f"{res['rates_by_risk']['High']['rate']:.1f}%",
        p_str
    ])

etable3_df = pd.DataFrame(
    etable3_rows,
    columns=['Outcome', 'N', 'Prevalence', 'AUROC', '95% CI',
             'Low', 'Moderate', 'High', 'p-value']
)

print("\n--- Binary Outcomes ---")
print(etable3_df.to_string(index=False))

# Continuous outcomes (LOS)
print("\n--- Length of Stay ---")
print(f"{'Outcome':<25} {'N':<8} {'Spearman rho':<15} {'95% CI':<25} {'p-value':<10}")
print("-"*85)

icu_los = results_secondary['icu_los']
hosp_los = results_secondary['hospital_los']

# Print both LOS rows with column widths matching the header above
for label, res in [('ICU LOS', icu_los), ('Hospital LOS', hosp_los)]:
    ci_str = f"({res['spearman_ci'][0]:.3f}, {res['spearman_ci'][1]:.3f})"
    p_str = "<0.001" if res['p_value'] < 0.001 else f"{res['p_value']:.3f}"
    print(f"{label:<25} {res['n']:<8} {res['spearman_rho']:<15.3f} {ci_str:<25} {p_str}")

# LOS by risk category
print("\n--- LOS by Risk Category ---")
print(f"{'Outcome':<15} {'Category':<12} {'N':<8} {'Median [IQR]':<25}")
print("-"*65)

for cat in ['Low', 'Moderate', 'High']:
    icu_data = icu_los['by_risk'][cat]
    iqr_str = f"{icu_data['median']:.1f} [{icu_data['iqr'][0]:.1f}-{icu_data['iqr'][1]:.1f}]"
    print(f"{'ICU LOS':<15} {cat:<12} {icu_data['n']:<8} {iqr_str}")

for cat in ['Low', 'Moderate', 'High']:
    hosp_data = hosp_los['by_risk'][cat]
    iqr_str = f"{hosp_data['median']:.1f} [{hosp_data['iqr'][0]:.1f}-{hosp_data['iqr'][1]:.1f}]"
    print(f"{'Hospital LOS':<15} {cat:<12} {hosp_data['n']:<8} {iqr_str}")

etable3_df.to_csv('/content/eTable3_secondary_outcomes.csv', index=False)
etable3_df.to_excel('/content/eTable3_secondary_outcomes.xlsx', index=False)
print("\n✓ eTable 3 saved")


# -------------------------------------------------------------------------
# eTable 4: HF PHENOTYPE SUBGROUP ANALYSIS
# -------------------------------------------------------------------------

print("\n" + "="*70)
print("eTable 4: HF Phenotype Subgroup Analysis")
print("="*70)

etable4_rows = []

for phenotype in ['HFrEF', 'HFmrEF', 'HFpEF']:
    if phenotype in results_subgroups['hf_results']:
        res = results_subgroups['hf_results'][phenotype]
        rho_ci = f"({res['spearman_ci'][0]:.3f}, {res['spearman_ci'][1]:.3f})"
        auroc_str = f"{res['auroc']:.3f}" if res.get('auroc', None) is not None else "NA"
        auroc_ci = f"({res['auroc_ci'][0]:.3f}, {res['auroc_ci'][1]:.3f})" if res.get('auroc_ci', None) is not None else "NA"
        etable4_rows.append([
            phenotype,
            res['n'],
            f"{res['spearman_rho']:.3f}",
            rho_ci,
            auroc_str,
            auroc_ci
        ])

etable4_df = pd.DataFrame(
    etable4_rows,
    columns=['HF Phenotype', 'N', 'Spearman rho', 'rho 95% CI',
             'AUROC (Q20)', 'AUROC 95% CI']
)

print("\n")
print(etable4_df.to_string(index=False))

# LVEF distribution
print("\n" + "-"*70)
print("LVEF Distribution by Phenotype:")
print("-"*70)

df_hf = df[df['hf_phenotype'].notna()].copy()
for phenotype in ['HFrEF', 'HFmrEF', 'HFpEF']:
    subset = df_hf[df_hf['hf_phenotype'] == phenotype]
    n = len(subset)
    lvef_mean = subset['lvef'].mean()
    lvef_std = subset['lvef'].std()
    pct = 100 * n / len(df_hf) if len(df_hf) > 0 else np.nan
    print(f"  {phenotype}: N={n} ({pct:.1f}%), LVEF = {lvef_mean:.1f} +/- {lvef_std:.1f}%")

etable4_df.to_csv('/content/eTable4_hf_phenotype.csv', index=False)
etable4_df.to_excel('/content/eTable4_hf_phenotype.xlsx', index=False)
print("\n✓ eTable 4 saved")


# -------------------------------------------------------------------------
# SUMMARY
# -------------------------------------------------------------------------

print("\n" + "="*70)
print("ALL SUPPLEMENTARY TABLES COMPLETE")
print("="*70)
print("""
Files generated:
  - eTable1_baseline_by_mortality.csv/.xlsx
  - eTable2_sensitivity_analyses.csv/.xlsx
  - eTable3_secondary_outcomes.csv/.xlsx
  - eTable4_hf_phenotype.csv/.xlsx
""")
print("\n-> Next: Main Figures (Cell 18)")
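
The eTable 3 and eTable 4 rows above report Spearman rho with a 95% CI read from `results_secondary` and `results_subgroups`, which are computed in earlier cells. As a cross-check only, an approximate interval can be obtained with Fisher's z-transformation. The sketch below assumes the standard error 1/sqrt(n - 3), with an optional Bonett-Wright variance, (1 + rho^2/2)/(n - 3), which is often recommended for Spearman's rho. The helper name `spearman_ci_fisher` is hypothetical and is not necessarily how the notebook's own CIs were produced.

```python
import math
from statistics import NormalDist


def spearman_ci_fisher(rho, n, alpha=0.05, bonett_wright=False):
    """Approximate (1 - alpha) CI for a Spearman correlation via Fisher's z.

    Hypothetical cross-check helper; not the notebook's own CI method.
    """
    if n <= 3:
        raise ValueError("n must be greater than 3")
    if not -1.0 < rho < 1.0:
        raise ValueError("rho must lie strictly between -1 and 1")
    z = math.atanh(rho)                      # Fisher z-transform
    var = 1.0 / (n - 3)                      # classic Fisher variance
    if bonett_wright:
        var *= 1.0 + rho ** 2 / 2.0          # Bonett-Wright inflation for Spearman
    z_crit = NormalDist().inv_cdf(1 - alpha / 2)
    half_width = z_crit * math.sqrt(var)
    # Back-transform the symmetric z-scale interval to the rho scale
    return math.tanh(z - half_width), math.tanh(z + half_width)


lo, hi = spearman_ci_fisher(-0.518, 1019)
lo_bw, hi_bw = spearman_ci_fisher(-0.518, 1019, bonett_wright=True)
print(f"Fisher:        ({lo:.3f}, {hi:.3f})")
print(f"Bonett-Wright: ({lo_bw:.3f}, {hi_bw:.3f})")
```

Applied to the headline 24-hour estimate (rho = -0.518, n = 1,019), the plain Fisher interval comes out close to the reported -0.560 to -0.473; the Bonett-Wright version is slightly wider, and bootstrap intervals will differ somewhat.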


#==========================================================================
# SECTION 14: MAIN MANUSCRIPT FIGURES
# Cell 18: Generate Main Figures (Python versions + R data exports)
#==========================================================================

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# ---------------------------
# Helper: find primary 24h results dict (the one with n, rho, ci, etc.)
# ---------------------------
def find_primary_24h_results():
    required = {"n", "spearman_rho", "spearman_ci", "auroc_quintile", "auroc_quintile_ci"}
    candidates = []
    for name, obj in globals().items():
        if isinstance(obj, dict) and required.issubset(set(obj.keys())):
            candidates.append((name, obj))

    if not candidates:
        maybe = []
        for name, obj in globals().items():
            if isinstance(obj, dict):
                keys = list(obj.keys())
                if any(("rho" in str(k).lower()) or ("auroc" in str(k).lower()) or ("c_index" in str(k).lower()) for k in keys):
                    maybe.append((name, keys[:30]))
        raise KeyError(
            "Could not find the primary 24h discrimination dict.\n"
            "Expected keys like: n, spearman_rho, spearman_ci, auroc_quintile, auroc_quintile_ci.\n"
            f"Possible related dicts: {maybe}"
        )

    # Prefer specific names if present
    preferred_names = {"results_primary_24h", "results_24h_primary", "results_discrimination_24h", "results_24h_discrimination"}
    for nm, ob in candidates:
        if nm.lower() in preferred_names:
            return nm, ob

    return candidates[0]


primary_24h_name, primary_24h = find_primary_24h_results()
print(f"Using primary 24h discrimination results from: {primary_24h_name}")

# ---------------------------
# Matplotlib defaults
# ---------------------------
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 500
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.size'] = 10
plt.rcParams['axes.linewidth'] = 1.0

print("="*70)
print("MAIN MANUSCRIPT FIGURES")
print("="*70)

#==========================================================================
# FIGURE 1: STUDY FLOW DIAGRAM (Text-based for now)
#==========================================================================

print("\n" + "="*70)
print("FIGURE 1: Study Flow Diagram")
print("="*70)

print("""
MIMIC-IV Database (2008-2022)
         |
         v
ICU admissions with HF diagnosis (ICD-9/10)
         |
         v
Applied inclusion criteria:
  - Age >= 18 years
  - Primary or secondary HF diagnosis
  - IV diuretic administration in ICU
  - ICU stay >= 24 hours
         |
         v
Excluded:
  - Missing BAN-ADHF score components
  - ESRD/dialysis at admission
  - Missing urine output data
         |
         v
FINAL COHORT: N = 1,505
         |
    +---------+---------+
    |         |         |
    v         v         v
 Low Risk  Moderate  High Risk
 (<=7)     (8-12)    (>=13)
 N=446     N=480     N=579
 (29.6%)   (31.9%)   (38.5%)

Note: Figure 1 will be created in PowerPoint/Illustrator for publication
""")

#==========================================================================
# FIGURE 2: DIURETIC EFFICIENCY BY RISK CATEGORY
#==========================================================================

print("\n" + "="*70)
print("FIGURE 2: Diuretic Efficiency by Risk Category")
print("="*70)

risk_order = ['Low', 'Moderate', 'High']

# 24h data
df_fig2_24h = df[
    (df['icu_stay_ge_24h'] == 1) &
    (df['diuretic_efficiency_24h'].notna()) &
    (df['diuretic_efficiency_24h'] > 0)
].copy()

# 72h data
df_fig2_72h = df[
    (df['icu_stay_ge_72h'] == 1) &
    (df['diuretic_efficiency_72h'].notna()) &
    (df['diuretic_efficiency_72h'] > 0)
].copy()

# Create figure with two panels
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Panel A: 24-hour efficiency box plots
ax1 = axes[0]
box_data_24h = [df_fig2_24h[df_fig2_24h['risk_category'] == cat]['diuretic_efficiency_24h'].values
                for cat in risk_order]

bp1 = ax1.boxplot(
    box_data_24h,
    positions=[1, 2, 3],
    widths=0.6,
    patch_artist=True,
    showfliers=False,
    medianprops=dict(color='black', linewidth=2)
)

# Boxes keep matplotlib's default face color (no custom palette); only transparency is set
for patch in bp1['boxes']:
    patch.set_alpha(0.7)

# Add jittered points (seeded generator so the jitter is reproducible across runs)
rng_24h = np.random.default_rng(42)
for i, cat in enumerate(risk_order, start=1):
    y = df_fig2_24h[df_fig2_24h['risk_category'] == cat]['diuretic_efficiency_24h'].values
    x = rng_24h.normal(i, 0.08, size=len(y))
    ax1.scatter(x, y, alpha=0.25, s=10, zorder=1)

# Add medians + sample sizes
for i, cat in enumerate(risk_order, start=1):
    med = df_fig2_24h[df_fig2_24h['risk_category'] == cat]['diuretic_efficiency_24h'].median()
    n = len(df_fig2_24h[df_fig2_24h['risk_category'] == cat])
    ax1.text(i, med + 5, f'{med:.1f}', ha='center', va='bottom', fontweight='bold', fontsize=10)
    ax1.text(i, -15, f'n={n}', ha='center', va='top', fontsize=9)

ax1.set_xticks([1, 2, 3])
ax1.set_xticklabels(['Low\n(<=7)', 'Moderate\n(8-12)', 'High\n(>=13)'])
ax1.set_ylabel('24-Hour Diuretic Efficiency (mL/mg)', fontsize=11)
ax1.set_xlabel('BAN-ADHF Risk Category', fontsize=11)
ax1.set_title('A. 24-Hour Diuretic Efficiency', fontsize=12, fontweight='bold')
ax1.set_ylim(-20, 250)
ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

# P-value annotation (use text + line, no arrows)
ax1.plot([1, 3], [230, 230], color='black', linewidth=1)
ax1.text(2, 235, 'p < 0.001', ha='center', va='bottom', fontsize=10)

# Panel B: 72-hour efficiency box plots
ax2 = axes[1]
box_data_72h = [df_fig2_72h[df_fig2_72h['risk_category'] == cat]['diuretic_efficiency_72h'].values
                for cat in risk_order]

bp2 = ax2.boxplot(
    box_data_72h,
    positions=[1, 2, 3],
    widths=0.6,
    patch_artist=True,
    showfliers=False,
    medianprops=dict(color='black', linewidth=2)
)

for patch in bp2['boxes']:
    patch.set_alpha(0.7)

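# Jittered points (seeded generator so the jitter is reproducible across runs)
rng_72h = np.random.default_rng(42)
for i, cat in enumerate(risk_order, start=1):
    y = df_fig2_72h[df_fig2_72h['risk_category'] == cat]['diuretic_efficiency_72h'].values
    x = rng_72h.normal(i, 0.08, size=len(y))
    ax2.scatter(x, y, alpha=0.25, s=10, zorder=1)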

for i, cat in enumerate(risk_order, start=1):
    med = df_fig2_72h[df_fig2_72h['risk_category'] == cat]['diuretic_efficiency_72h'].median()
    n = len(df_fig2_72h[df_fig2_72h['risk_category'] == cat])
    ax2.text(i, med + 5, f'{med:.1f}', ha='center', va='bottom', fontweight='bold', fontsize=10)
    ax2.text(i, -15, f'n={n}', ha='center', va='top', fontsize=9)

ax2.set_xticks([1, 2, 3])
ax2.set_xticklabels(['Low\n(<=7)', 'Moderate\n(8-12)', 'High\n(>=13)'])
ax2.set_ylabel('72-Hour Diuretic Efficiency (mL/mg)', fontsize=11)
ax2.set_xlabel('BAN-ADHF Risk Category', fontsize=11)
ax2.set_title('B. 72-Hour Diuretic Efficiency', fontsize=12, fontweight='bold')
ax2.set_ylim(-20, 250)
ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

ax2.plot([1, 3], [230, 230], color='black', linewidth=1)
ax2.text(2, 235, 'p < 0.001', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig('/content/Figure2_efficiency_by_risk.png', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.savefig('/content/Figure2_efficiency_by_risk.tiff', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.show()
print("\n✓ Figure 2 saved (PNG and TIFF)")
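
Both Figure 2 panels label the difference across risk categories with a fixed string, "p < 0.001". If that label is meant to summarize a test across the three categories, one way to derive it from the plotted data instead of hard-coding it is a Kruskal-Wallis H test, which suits skewed efficiency values. This is an assumption: the test behind the reported p-value is computed elsewhere and may differ (for example, a trend test). `kruskal_by_category` and `format_p` are hypothetical helpers; `scipy.stats.kruskal` is the SciPy function.

```python
import numpy as np
from scipy.stats import kruskal


def kruskal_by_category(values, categories, order=("Low", "Moderate", "High")):
    """Kruskal-Wallis H test of values across the given categories (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    categories = np.asarray(categories)
    groups = [values[(categories == c) & ~np.isnan(values)] for c in order]
    groups = [g for g in groups if len(g) > 0]   # drop empty categories
    if len(groups) < 2:
        raise ValueError("need at least two non-empty categories")
    h_stat, p_value = kruskal(*groups)
    return h_stat, p_value


def format_p(p):
    """Format a p-value the way the figure annotations do."""
    return "p < 0.001" if p < 0.001 else f"p = {p:.3f}"


# Synthetic data only; in the notebook one would pass, e.g.,
# df_fig2_24h['diuretic_efficiency_24h'] and df_fig2_24h['risk_category']
rng = np.random.default_rng(1)
eff = np.concatenate([rng.normal(47, 15, 100), rng.normal(29, 15, 100), rng.normal(11, 15, 100)])
cats = np.repeat(["Low", "Moderate", "High"], 100)
h, p = kruskal_by_category(eff, cats)
print(f"H = {h:.1f}, {format_p(p)}")
```

The resulting string could then replace the literal in the `ax1.text(...)` and `ax2.text(...)` annotations.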

#==========================================================================
# FIGURE 3: SUBGROUP FOREST PLOT
#==========================================================================

print("\n" + "="*70)
print("FIGURE 3: Subgroup Forest Plot")
print("="*70)

forest_df = results_subgroups['forest_data'].copy()

overall_row = pd.DataFrame({
    'subgroup': ['Overall'],
    'n': [primary_24h['n']],
    'rho': [primary_24h['spearman_rho']],
    'rho_ci_low': [primary_24h['spearman_ci'][0]],
    'rho_ci_high': [primary_24h['spearman_ci'][1]]
})

forest_df = pd.concat([overall_row, forest_df], ignore_index=True)

# Sort so Overall is at top, then keep existing order for the rest
forest_df['__order'] = range(len(forest_df))
forest_df.loc[forest_df['subgroup'] == 'Overall', '__order'] = -1
forest_df = forest_df.sort_values('__order').drop(columns='__order').reset_index(drop=True)

fig, ax = plt.subplots(figsize=(10, 12))

y = np.arange(len(forest_df))
# Plot CIs + points
for i in range(len(forest_df)):
    row = forest_df.iloc[i]
    is_overall = (row['subgroup'] == 'Overall')

    ax.hlines(y[i], row['rho_ci_low'], row['rho_ci_high'], linewidth=2 if is_overall else 1.5, color='black')
    ax.plot(row['rho'], y[i], marker='D' if is_overall else 'o',
            markersize=10 if is_overall else 7, color='black')

# Reference lines
ax.axvline(x=primary_24h['spearman_rho'], color='gray', linestyle='--', alpha=0.5, zorder=0)
ax.axvline(x=0, color='black', linestyle='-', alpha=0.3, zorder=0)

ax.set_yticks(y)
ax.set_yticklabels(forest_df['subgroup'].tolist())
ax.invert_yaxis()

ax.set_xlabel('Spearman correlation coefficient (rho)', fontsize=12)
ax.set_title('Figure 3: Subgroup analysis of BAN-ADHF score vs 24-hour diuretic efficiency',
             fontsize=12, fontweight='bold')
ax.set_xlim(-0.8, 0.1)
ax.grid(axis='x', alpha=0.3)

# Right-side text (n and rho)
ax_right = ax.twinx()
ax_right.set_ylim(ax.get_ylim())
ax_right.set_yticks(y)
right_labels = [
    f"n={int(row['n']):,}  rho={row['rho']:.3f} ({row['rho_ci_low']:.3f}, {row['rho_ci_high']:.3f})"
    for _, row in forest_df.iterrows()
]
ax_right.set_yticklabels(right_labels, fontsize=9)
ax_right.tick_params(axis='y', length=0)

# Separator after Overall row (if present)
if (forest_df['subgroup'] == 'Overall').any() and len(forest_df) > 1:
    ax.axhline(y=0.5, color='gray', linestyle='-', alpha=0.5)

plt.tight_layout()
plt.savefig('/content/Figure3_forest_plot.png', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.savefig('/content/Figure3_forest_plot.tiff', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.show()
print("\n✓ Figure 3 saved (PNG and TIFF)")

#==========================================================================
# INTERACTION P-VALUES
#==========================================================================

print("\n" + "="*70)
print("INTERACTION TEST RESULTS")
print("="*70)

print(f"\n{'Subgroup':<30} {'rho_1':<10} {'rho_2':<10} {'Delta':<10} {'p_interaction':<15} {'Significant'}")
print("-"*90)

interaction_data = []
for subgroup, data in results_subgroups['interaction_results'].items():
    p_int = data['p_interaction']
    sig = "Yes *" if p_int < 0.05 else "No"
    print(f"{subgroup:<30} {data['rho_1']:<10.3f} {data['rho_2']:<10.3f} {data['delta']:<10.3f} {p_int:<15.4f} {sig}")

    interaction_data.append({
        'subgroup': subgroup,
        'rho_group1': data['rho_1'],
        'rho_group2': data['rho_2'],
        'delta_rho': data['delta'],
        'z_stat': data['z_stat'],
        'p_interaction': p_int,
        'significant': p_int < 0.05
    })

print("\n* Significant interaction (p < 0.05)")

#==========================================================================
# R DATA EXPORTS
#==========================================================================

print("\n" + "="*70)
print("R DATA EXPORTS FOR PUBLICATION-QUALITY FIGURES")
print("="*70)

fig2_data = df[['hadm_id', 'ban_adhf_total_score', 'risk_category',
                'diuretic_efficiency_24h', 'diuretic_efficiency_72h',
                'icu_stay_ge_24h', 'icu_stay_ge_72h']].copy()
fig2_data.to_csv('/content/data_for_R_figure2_efficiency.csv', index=False)
print("✓ data_for_R_figure2_efficiency.csv")

forest_df.to_csv('/content/data_for_R_figure3_forest.csv', index=False)
print("✓ data_for_R_figure3_forest.csv")

fig2_summary = []
for timepoint in ['24h', '72h']:
    if timepoint == '24h':
        df_temp = df[(df['icu_stay_ge_24h'] == 1) &
                     (df['diuretic_efficiency_24h'].notna()) &
                     (df['diuretic_efficiency_24h'] > 0)].copy()
        eff_col = 'diuretic_efficiency_24h'
    else:
        df_temp = df[(df['icu_stay_ge_72h'] == 1) &
                     (df['diuretic_efficiency_72h'].notna()) &
                     (df['diuretic_efficiency_72h'] > 0)].copy()
        eff_col = 'diuretic_efficiency_72h'

    for cat in ['Low', 'Moderate', 'High']:
        subset = df_temp[df_temp['risk_category'] == cat][eff_col]
        fig2_summary.append({
            'timepoint': timepoint,
            'risk_category': cat,
            'n': len(subset),
            'median': subset.median(),
            'q1': subset.quantile(0.25),
            'q3': subset.quantile(0.75),
            'mean': subset.mean(),
            'sd': subset.std()
        })

pd.DataFrame(fig2_summary).to_csv('/content/data_for_R_figure2_summary.csv', index=False)
print("✓ data_for_R_figure2_summary.csv")

pd.DataFrame(interaction_data).to_csv('/content/data_for_R_interactions.csv', index=False)
print("✓ data_for_R_interactions.csv")

# Detailed subgroup data for extended forest plot
subgroup_detail = []
for subgroup_name, subgroup_data in results_subgroups['subgroup_results'].items():
    for level, metrics in subgroup_data.items():
        # AUROC and its CI can be None in small strata; export NaN rather than raising
        auroc = metrics.get('auroc')
        auroc_ci = metrics.get('auroc_ci')
        if auroc_ci is None:
            auroc_ci = (np.nan, np.nan)
        subgroup_detail.append({
            'subgroup': subgroup_name,
            'level': level,
            'n': metrics['n'],
            'spearman_rho': metrics['spearman_rho'],
            'rho_ci_low': metrics['spearman_ci'][0],
            'rho_ci_high': metrics['spearman_ci'][1],
            'auroc': np.nan if auroc is None else auroc,
            'auroc_ci_low': auroc_ci[0],
            'auroc_ci_high': auroc_ci[1]
        })

pd.DataFrame(subgroup_detail).to_csv('/content/data_for_R_subgroups_detail.csv', index=False)
print("✓ data_for_R_subgroups_detail.csv")

#==========================================================================
# SUMMARY
#==========================================================================

print("\n" + "="*70)
print("ALL MAIN FIGURES COMPLETE")
print("="*70)
print("""
Figures generated:
  - Figure2_efficiency_by_risk.png/.tiff (500 dpi)
  - Figure3_forest_plot.png/.tiff (500 dpi)

R data exports:
  - data_for_R_figure2_efficiency.csv (raw data)
  - data_for_R_figure2_summary.csv (summary statistics)
  - data_for_R_figure3_forest.csv (forest plot data with overall)
  - data_for_R_interactions.csv (interaction p-values)
  - data_for_R_subgroups_detail.csv (detailed subgroup metrics)

Note: Figure 1 (Study Flow) should be created in PowerPoint/Illustrator

SIGNIFICANT INTERACTIONS (p < 0.05):
""")

for row in interaction_data:
    if row['significant']:
        print(f"  - {row['subgroup']}: p = {row['p_interaction']:.4f}")

print("\n-> Next: Supplementary Figures (Cell 19)")
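
The interaction table above reads `rho_1`, `rho_2`, `delta`, `z_stat` and `p_interaction` from `results_subgroups['interaction_results']`, which an earlier cell computes. Those keys suggest a z-test comparing the correlation in two independent subgroups; a common version uses Fisher's z-transformation, sketched below under that assumption. `compare_independent_correlations` is a hypothetical name, and the earlier cell may use a different variance (for example, the Bonett-Wright correction for Spearman's rho).

```python
import math


def compare_independent_correlations(r1, n1, r2, n2):
    """Two-sided z test for H0: rho1 == rho2 in two independent samples.

    Uses Fisher's z-transform with variance 1/(n - 3) per group.
    Hypothetical helper; returns (delta, z_stat, p_value).
    """
    if min(n1, n2) <= 3:
        raise ValueError("each group needs more than 3 observations")
    z1, z2 = math.atanh(r1), math.atanh(r2)
    se_diff = math.sqrt(1.0 / (n1 - 3) + 1.0 / (n2 - 3))
    z_stat = (z1 - z2) / se_diff
    # Two-sided p from the standard normal: 2 * (1 - Phi(|z|)) == erfc(|z| / sqrt(2))
    p_value = math.erfc(abs(z_stat) / math.sqrt(2.0))
    return r1 - r2, z_stat, p_value


# Illustrative inputs only (not study results)
delta, z_stat, p_val = compare_independent_correlations(-0.60, 300, -0.40, 300)
print(f"delta = {delta:.3f}, z = {z_stat:.2f}, p = {p_val:.4f}")
```

Because the test works on the z scale, the same absolute difference in rho is easier to detect when both correlations are strong, and subgroup sizes enter only through n - 3.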


#==========================================================================
# SECTION 15: SUPPLEMENTARY FIGURES
# Cell 19: Generate Supplementary Figures (eFigures 1-3)
#   - Uses df and the results dicts produced by earlier cells
#   - Uses a fixed Low/Moderate/High risk color palette
#   - Uses matplotlib only (no seaborn dependency)
#   - ROC panels degrade gracefully when columns are missing or contain NaNs
#==========================================================================

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.metrics import roc_curve, auc
import warnings
warnings.filterwarnings('ignore')

# -----------------------------
# Publication-quality defaults
# -----------------------------
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 500
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.size'] = 10
plt.rcParams['axes.linewidth'] = 1.0

# -----------------------------
# Fixed risk-category palette
# -----------------------------
RISK_COLORS = {'Low': '#2ecc71', 'Moderate': '#f39c12', 'High': '#e74c3c'}
RISK_ORDER = ['Low', 'Moderate', 'High']

# -----------------------------
# Helper: safe ROC computation
# -----------------------------
def compute_roc(y_true, y_score):
    """
    Returns (fpr, tpr, auc_value) or (None, None, None) if invalid
    """
    y_true = pd.Series(y_true).astype(float)
    y_score = pd.Series(y_score).astype(float)

    valid = y_true.notna() & y_score.notna()
    y_true = y_true[valid]
    y_score = y_score[valid]

    # Need both classes present
    if y_true.nunique() < 2 or len(y_true) < 10:
        return None, None, None

    fpr, tpr, _ = roc_curve(y_true, y_score)
    return fpr, tpr, auc(fpr, tpr)

print("="*70)
print("SUPPLEMENTARY FIGURES")
print("="*70)

#==========================================================================
# eFIGURE 1: ROC CURVES COMPARISON
#==========================================================================

print("\n" + "="*70)
print("eFigure 1: ROC Curves for Binary Outcomes")
print("="*70)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

df_roc = df.copy()

#------------------------------------------------------------------------------
# Panel A: Lowest Quintile (Q20) Efficiency (24h)
#------------------------------------------------------------------------------
ax1 = axes[0]

df_24h_roc = df_roc[(df_roc['icu_stay_ge_24h'] == 1) &
                    (df_roc['diuretic_efficiency_24h'].notna()) &
                    (df_roc['diuretic_efficiency_24h'] > 0)].copy()

q20_threshold = df_24h_roc['diuretic_efficiency_24h'].quantile(0.20)
df_24h_roc['low_quintile'] = (df_24h_roc['diuretic_efficiency_24h'] <= q20_threshold).astype(int)

# BAN-ADHF score ROC (note: higher score predicts lower efficiency, so this is correct direction)
fpr_score, tpr_score, auc_score = compute_roc(df_24h_roc['low_quintile'], df_24h_roc['ban_adhf_total_score'])

# Components ROCs (if present)
component_defs = [
    ('creatinine', 'Creatinine'),
    ('bun', 'BUN'),
    ('ntprobnp', 'NT-proBNP')
]
colors_comp = {'Creatinine': '#e74c3c', 'BUN': '#9b59b6', 'NT-proBNP': '#3498db'}

if fpr_score is not None:
    ax1.plot(fpr_score, tpr_score, color='black', linewidth=2.5,
             label=f'BAN-ADHF Score (AUC={auc_score:.3f})')
else:
    ax1.text(0.5, 0.5, 'Insufficient data for ROC', ha='center', va='center', fontsize=11)

ax1.plot([0, 1], [0, 1], 'k--', linewidth=1, alpha=0.5, label='Reference')

for col, label in component_defs:
    if col in df_24h_roc.columns:
        fpr_c, tpr_c, auc_c = compute_roc(df_24h_roc['low_quintile'], df_24h_roc[col])
        if fpr_c is not None:
            ax1.plot(fpr_c, tpr_c, color=colors_comp.get(label, 'gray'),
                     linewidth=1.5, linestyle='--', alpha=0.8,
                     label=f'{label} (AUC={auc_c:.3f})')

ax1.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=11)
ax1.set_ylabel('Sensitivity (True Positive Rate)', fontsize=11)
ax1.set_title('A. Lowest Quintile Efficiency (Q20)', fontsize=12, fontweight='bold')
ax1.legend(loc='lower right', fontsize=9)
ax1.set_xlim([0, 1])
ax1.set_ylim([0, 1])
ax1.set_aspect('equal')
ax1.grid(alpha=0.3)

#------------------------------------------------------------------------------
# Panel B: Diuretic Resistance
#   Uses existing df['diuretic_resistance'] if present
#   Otherwise reconstructs using UOP <= 3000 mL if a suitable UOP column exists
#------------------------------------------------------------------------------
ax2 = axes[1]

df_dr_roc = df_roc[(df_roc['icu_stay_ge_24h'] == 1)].copy()

if 'diuretic_resistance' in df_dr_roc.columns:
    y_dr = df_dr_roc['diuretic_resistance']
else:
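    # Try to reconstruct from common UOP column names
    uop_candidates = ['urine_output_24h', 'uop_24h', 'uop_first_24h', 'urine_output_first_24h']
    uop_col = next((c for c in uop_candidates if c in df_dr_roc.columns), None)
    if uop_col is None:
        y_dr = None
    else:
        # Keep missing UOP as NaN so compute_roc drops it
        # (NaN <= 3000 evaluates False and would otherwise be coded as non-resistant)
        uop = df_dr_roc[uop_col]
        y_dr = (uop <= 3000).astype(float).where(uop.notna())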

if y_dr is None:
    ax2.text(0.5, 0.5, 'No diuretic_resistance or UOP column found', ha='center', va='center', fontsize=11)
    auc_dr = np.nan
else:
    fpr_dr, tpr_dr, auc_dr = compute_roc(y_dr, df_dr_roc['ban_adhf_total_score'])
    if fpr_dr is not None:
        ax2.plot(fpr_dr, tpr_dr, color='black', linewidth=2.5, label=f'BAN-ADHF Score (AUC={auc_dr:.3f})')
    else:
        ax2.text(0.5, 0.5, 'Insufficient data for ROC', ha='center', va='center', fontsize=11)

ax2.plot([0, 1], [0, 1], 'k--', linewidth=1, alpha=0.5, label='Reference')
ax2.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=11)
ax2.set_ylabel('Sensitivity (True Positive Rate)', fontsize=11)
ax2.set_title('B. Diuretic Resistance (UOP <=3000 mL)', fontsize=12, fontweight='bold')
ax2.legend(loc='lower right', fontsize=9)
ax2.set_xlim([0, 1])
ax2.set_ylim([0, 1])
ax2.set_aspect('equal')
ax2.grid(alpha=0.3)

if y_dr is not None and pd.Series(y_dr).notna().any():
    prevalence = pd.Series(y_dr).mean() * 100
    ax2.text(0.95, 0.05, f'Prevalence: {prevalence:.1f}%', transform=ax2.transAxes,
             ha='right', va='bottom', fontsize=10, style='italic')

#------------------------------------------------------------------------------
# Panel C: In-Hospital Mortality
#------------------------------------------------------------------------------
ax3 = axes[2]

df_mort_roc = df_roc[df_roc['hospital_expire_flag'].notna()].copy()
fpr_mort, tpr_mort, auc_mort = compute_roc(df_mort_roc['hospital_expire_flag'], df_mort_roc['ban_adhf_total_score'])

# CS subgroup
if 'cardiogenic_shock' in df_mort_roc.columns:
    df_cs = df_mort_roc[df_mort_roc['cardiogenic_shock'] == 1].copy()
else:
    df_cs = df_mort_roc.iloc[0:0].copy()

fpr_cs, tpr_cs, auc_cs = (None, None, None)
if len(df_cs) > 0:
    fpr_cs, tpr_cs, auc_cs = compute_roc(df_cs['hospital_expire_flag'], df_cs['ban_adhf_total_score'])

if fpr_mort is not None:
    ax3.plot(fpr_mort, tpr_mort, color='black', linewidth=2.5, label=f'All patients (AUC={auc_mort:.3f})')
else:
    ax3.text(0.5, 0.55, 'Insufficient mortality data for ROC', ha='center', va='center', fontsize=11)

if fpr_cs is not None:
    ax3.plot(fpr_cs, tpr_cs, color=RISK_COLORS['High'], linewidth=2.0, label=f'CS subgroup (AUC={auc_cs:.3f})')

ax3.plot([0, 1], [0, 1], 'k--', linewidth=1, alpha=0.5, label='Reference')

ax3.set_xlabel('1 - Specificity (False Positive Rate)', fontsize=11)
ax3.set_ylabel('Sensitivity (True Positive Rate)', fontsize=11)
ax3.set_title('C. In-Hospital Mortality', fontsize=12, fontweight='bold')
ax3.legend(loc='lower right', fontsize=9)
ax3.set_xlim([0, 1])
ax3.set_ylim([0, 1])
ax3.set_aspect('equal')
ax3.grid(alpha=0.3)

mort_rate = df_mort_roc['hospital_expire_flag'].mean() * 100 if len(df_mort_roc) else np.nan
if not np.isnan(mort_rate):
    ax3.text(0.95, 0.05, f'Mortality: {mort_rate:.1f}%', transform=ax3.transAxes,
             ha='right', va='bottom', fontsize=10, style='italic')

plt.tight_layout()
plt.savefig('/content/eFigure1_ROC_curves.png', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.savefig('/content/eFigure1_ROC_curves.tiff', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.show()
print("\n✓ eFigure 1 saved (PNG and TIFF)")
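
eFigure 1 shows point AUCs in the legends, while the tables report AUROCs with 95% CIs computed in earlier cells. If a CI is wanted alongside each curve, a percentile bootstrap is one common option; the sketch below assumes that approach, and the earlier cells may use a different method (such as DeLong's). `bootstrap_auroc_ci` is a hypothetical helper; `sklearn.metrics.roc_auc_score` is the scikit-learn function. Resamples containing only one outcome class are skipped because AUROC is undefined for them.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def bootstrap_auroc_ci(y_true, y_score, n_boot=2000, alpha=0.05, seed=42):
    """Percentile-bootstrap CI for AUROC (hypothetical helper).

    Returns (point_estimate, (ci_low, ci_high)).
    """
    y_true = np.asarray(y_true, dtype=int)
    y_score = np.asarray(y_score, dtype=float)
    rng = np.random.default_rng(seed)
    n = len(y_true)
    point = roc_auc_score(y_true, y_score)
    boots = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)       # resample patients with replacement
        yt = y_true[idx]
        if yt.min() == yt.max():               # only one class: AUROC undefined
            continue
        boots.append(roc_auc_score(yt, y_score[idx]))
    lo, hi = np.percentile(boots, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return point, (lo, hi)


# Synthetic example only: positives score higher on average
rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=400)
score = y * 2.0 + rng.normal(0, 2.0, size=400)
auc_point, (auc_lo, auc_hi) = bootstrap_auroc_ci(y, score, n_boot=500)
print(f"AUROC {auc_point:.3f} (95% CI {auc_lo:.3f}-{auc_hi:.3f})")
```

In this notebook it could be called as `bootstrap_auroc_ci(df_24h_roc['low_quintile'], df_24h_roc['ban_adhf_total_score'])` after dropping rows with missing values, mirroring what `compute_roc` does.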

#==========================================================================
# eFIGURE 2: BAN-ADHF SCORE DISTRIBUTION
#==========================================================================

print("\n" + "="*70)
print("eFigure 2: BAN-ADHF Score Distribution")
print("="*70)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

#------------------------------------------------------------------------------
# Panel A: Score Distribution with Risk Categories
#------------------------------------------------------------------------------
ax1 = axes[0]

for cat in RISK_ORDER:
    subset = df[df['risk_category'] == cat]['ban_adhf_total_score'].dropna()
    ax1.hist(subset, bins=range(0, 26), alpha=0.6, label=f'{cat} (n={len(subset)})',
             color=RISK_COLORS[cat], edgecolor='white')

# Thresholds (<=7, 8-12, >=13)
ax1.axvline(x=7.5, color='black', linestyle='--', linewidth=2, label='Risk thresholds')
ax1.axvline(x=12.5, color='black', linestyle='--', linewidth=2)

ax1.set_xlabel('BAN-ADHF Score', fontsize=11)
ax1.set_ylabel('Frequency', fontsize=11)
ax1.set_title('A. Score Distribution by Risk Category', fontsize=12, fontweight='bold')
ax1.legend(loc='upper right', fontsize=9)
ax1.set_xlim([0, 25])

median_score = df['ban_adhf_total_score'].median()
iqr_low = df['ban_adhf_total_score'].quantile(0.25)
iqr_high = df['ban_adhf_total_score'].quantile(0.75)
ax1.text(0.02, 0.98, f'Median [IQR]: {median_score:.0f} [{iqr_low:.0f}-{iqr_high:.0f}]',
         transform=ax1.transAxes, ha='left', va='top', fontsize=10,
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

#------------------------------------------------------------------------------
# Panel B: Score Distribution by Mortality
#------------------------------------------------------------------------------
ax2 = axes[1]

survivors = df[df['hospital_expire_flag'] == 0]['ban_adhf_total_score'].dropna()
non_survivors = df[df['hospital_expire_flag'] == 1]['ban_adhf_total_score'].dropna()

ax2.hist(survivors, bins=range(0, 26), alpha=0.6, label=f'Survivors (n={len(survivors)})',
         color='#3498db', edgecolor='white', density=True)
ax2.hist(non_survivors, bins=range(0, 26), alpha=0.6, label=f'Non-survivors (n={len(non_survivors)})',
         color=RISK_COLORS['High'], edgecolor='white', density=True)

ax2.axvline(x=survivors.median(), color='#3498db', linestyle='-', linewidth=2)
ax2.axvline(x=non_survivors.median(), color=RISK_COLORS['High'], linestyle='-', linewidth=2)

ax2.set_xlabel('BAN-ADHF Score', fontsize=11)
ax2.set_ylabel('Density', fontsize=11)
ax2.set_title('B. Score Distribution by Mortality Status', fontsize=12, fontweight='bold')
ax2.legend(loc='upper right', fontsize=9)
ax2.set_xlim([0, 25])

ax2.text(0.02, 0.98,
         f'Survivors median: {survivors.median():.0f}\nNon-survivors median: {non_survivors.median():.0f}',
         transform=ax2.transAxes, ha='left', va='top', fontsize=10,
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.savefig('/content/eFigure2_score_distribution.png', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.savefig('/content/eFigure2_score_distribution.tiff', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.show()
print("\n✓ eFigure 2 saved (PNG and TIFF)")

#==========================================================================
# eFIGURE 3: CORRELATION SCATTER PLOT
#   Spearman rho and Pearson r are computed from the plotted 24h cohort
#==========================================================================

print("\n" + "="*70)
print("eFigure 3: Correlation Scatter Plot")
print("="*70)

fig, ax = plt.subplots(figsize=(8, 6))

df_scatter = df[(df['icu_stay_ge_24h'] == 1) &
                (df['diuretic_efficiency_24h'].notna()) &
                (df['diuretic_efficiency_24h'] > 0)].copy()

df_scatter['eff_capped'] = df_scatter['diuretic_efficiency_24h'].clip(upper=200)

for cat in RISK_ORDER:
    subset = df_scatter[df_scatter['risk_category'] == cat]
    ax.scatter(subset['ban_adhf_total_score'], subset['eff_capped'],
               alpha=0.4, s=20, color=RISK_COLORS[cat], label=cat)

# Regression line
z = np.polyfit(df_scatter['ban_adhf_total_score'], df_scatter['eff_capped'], 1)
p = np.poly1d(z)
x_line = np.linspace(0, 25, 100)
ax.plot(x_line, p(x_line), color='black', linewidth=2, label='Linear fit')

# Prefer correlations already stored in results_24h (flat keys 'spearman_rho' / 'pearson_r',
# as used in Cell 20). If it is scenario-keyed or missing them, compute from the plotted data.
from scipy import stats

x_sc = df_scatter['ban_adhf_total_score']
y_sc = df_scatter['diuretic_efficiency_24h']  # uncapped values; capping is for display only

rho = results_24h.get('spearman_rho') if isinstance(results_24h, dict) else None
pearson_r = results_24h.get('pearson_r') if isinstance(results_24h, dict) else None

# Compute the p-value from the data instead of hardcoding it in the annotation
rho_calc, p_spearman = stats.spearmanr(x_sc, y_sc)
if rho is None or pd.isna(rho):
    rho = rho_calc
if pearson_r is None or pd.isna(pearson_r):
    pearson_r = stats.pearsonr(x_sc, y_sc)[0]

p_label = 'p < 0.001' if p_spearman < 0.001 else f'p = {p_spearman:.3f}'
ax.text(0.95, 0.95, f'Spearman rho = {rho:.3f}\nPearson r = {pearson_r:.3f}\n{p_label}',
        transform=ax.transAxes, ha='right', va='top', fontsize=11,
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

ax.set_xlabel('BAN-ADHF Score', fontsize=12)
ax.set_ylabel('24-Hour Diuretic Efficiency (mL/mg)', fontsize=12)
ax.set_title('eFigure 3: BAN-ADHF Score vs Diuretic Efficiency', fontsize=12, fontweight='bold')
ax.legend(loc='center right', fontsize=10)  # keep clear of the statistics box in the upper right
ax.set_xlim([0, 25])
ax.set_ylim([0, 210])
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('/content/eFigure3_scatter_correlation.png', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.savefig('/content/eFigure3_scatter_correlation.tiff', dpi=500, bbox_inches='tight',
            facecolor='white', edgecolor='none')
plt.show()
print("\n✓ eFigure 3 saved (PNG and TIFF)")

#==========================================================================
# R DATA EXPORTS FOR SUPPLEMENTARY FIGURES
#==========================================================================

print("\n" + "="*70)
print("R DATA EXPORTS FOR SUPPLEMENTARY FIGURES")
print("="*70)

# ROC data for eFigure 1 (primary lowest-quintile curve only; thresholds are not exported)
# globals().get avoids a NameError if the ROC cell was skipped
_fpr = globals().get('fpr_score')
_tpr = globals().get('tpr_score')
roc_data = pd.DataFrame({
    'fpr_q20': pd.Series(_fpr, dtype=float),
    'tpr_q20': pd.Series(_tpr, dtype=float),
})
roc_data.to_csv('/content/data_for_R_roc_q20.csv', index=False)
print("✓ data_for_R_roc_q20.csv")

# Score distribution data
score_dist = df[['ban_adhf_total_score', 'risk_category', 'hospital_expire_flag']].copy()
score_dist.to_csv('/content/data_for_R_score_distribution.csv', index=False)
print("✓ data_for_R_score_distribution.csv")

# Scatter plot data
scatter_data = df_scatter[['ban_adhf_total_score', 'diuretic_efficiency_24h', 'risk_category']].copy()
scatter_data.to_csv('/content/data_for_R_scatter.csv', index=False)
print("✓ data_for_R_scatter.csv")

#==========================================================================
# SUMMARY
#==========================================================================

print("\n" + "="*70)
print("ALL SUPPLEMENTARY FIGURES COMPLETE")
print("="*70)
print(f"""
Figures generated:
  - eFigure1_ROC_curves.png/.tiff (500 dpi)
  - eFigure2_score_distribution.png/.tiff (500 dpi)
  - eFigure3_scatter_correlation.png/.tiff (500 dpi)

R data exports:
  - data_for_R_roc_q20.csv
  - data_for_R_score_distribution.csv
  - data_for_R_scatter.csv

Summary of AUROCs:
  - Lowest quintile (Q20): {auc_score:.3f}
  - Diuretic resistance: {auc_dr:.3f}
  - Mortality (all): {auc_mort:.3f}
  - Mortality (CS subgroup): {auc_cs:.3f}
""")

print("\n-> Next: Abstract and Results Text (Cell 20)")

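eFigure 3 and the manuscript text report Spearman's rho together with a p-value and a 95% CI. A minimal sketch of one way to obtain all three from paired data is below. It assumes `scipy` is available and uses a Fisher z-transformation with the Bonett-Wright variance approximation for the interval; `spearman_with_ci` and `fmt_p` are hypothetical helpers, and the notebook's own CIs (for example `results_24h['spearman_ci']`) may have been obtained differently, such as by bootstrapping. The data are synthetic and for illustration only.

```python
import numpy as np
from scipy import stats


def spearman_with_ci(x, y, alpha=0.05):
    """Spearman rho, two-sided p-value, and an approximate (1 - alpha) CI.

    Hypothetical helper. The CI uses the Fisher z-transform with the
    Bonett-Wright (2000) variance (1 + rho^2 / 2) / (n - 3).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    keep = ~(np.isnan(x) | np.isnan(y))  # complete pairs only
    x, y = x[keep], y[keep]
    n = x.size
    if n < 4:
        raise ValueError("need at least 4 complete pairs")

    rho, p = stats.spearmanr(x, y)
    z = np.arctanh(rho)                          # Fisher z of rho
    se = np.sqrt((1 + rho ** 2 / 2) / (n - 3))   # Bonett-Wright standard error
    z_crit = stats.norm.ppf(1 - alpha / 2)       # about 1.96 for a 95% CI
    lo, hi = np.tanh([z - z_crit * se, z + z_crit * se])
    return {"n": int(n), "rho": float(rho), "p": float(p), "ci": (float(lo), float(hi))}


def fmt_p(p):
    """Report very small p-values with a floor instead of hardcoding 'p < 0.001'."""
    return "p < 0.001" if p < 0.001 else f"p = {p:.3f}"


# Synthetic illustration (not MIMIC-IV data): higher score, lower efficiency
rng = np.random.default_rng(42)
score = rng.integers(0, 26, size=500)
efficiency = 60 - 2 * score + rng.normal(0, 5, size=500)
res = spearman_with_ci(score, efficiency)
print(f"rho = {res['rho']:.3f} (95% CI {res['ci'][0]:.3f} to {res['ci'][1]:.3f}), {fmt_p(res['p'])}")
```

Because the interval is built on the z scale and back-transformed, it is asymmetric around rho, which is the expected behavior for correlations near -1 or 1.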

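The primary binary metric, the AUROC for the lowest efficiency quintile, can be sketched as follows. This is a minimal illustration assuming scikit-learn is available; `auroc_lowest_quantile` is a hypothetical helper. It labels efficiency at or below the 20th percentile as the positive class and uses the raw BAN-ADHF score as the ranking variable, so a higher score should rank a patient as more likely to fall in that quintile. Tie handling at the cutoff is an assumption: with `<=`, tied efficiencies can place slightly more than 20% of patients in the positive class.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve


def auroc_lowest_quantile(score, efficiency, q=0.20):
    """AUROC of `score` for identifying efficiency at or below the q-th quantile.

    Hypothetical helper: returns (auroc, cutoff, n_positive, fpr, tpr).
    """
    score = np.asarray(score, dtype=float)
    efficiency = np.asarray(efficiency, dtype=float)
    keep = ~(np.isnan(score) | np.isnan(efficiency))
    score, efficiency = score[keep], efficiency[keep]

    cutoff = np.quantile(efficiency, q)          # e.g. the 20th percentile
    y = (efficiency <= cutoff).astype(int)       # 1 = lowest-efficiency group
    if y.min() == y.max():
        raise ValueError("outcome has a single class; AUROC is undefined")

    auroc = roc_auc_score(y, score)              # higher score = higher predicted risk
    fpr, tpr, _ = roc_curve(y, score)            # points for an eFigure 1-style curve
    return auroc, float(cutoff), int(y.sum()), fpr, tpr


# Synthetic check: efficiency falls strictly as the score rises, so ranking is perfect
eff = np.arange(1, 101, dtype=float)
scr = 100 - eff
auc, cut, n_pos, fpr, tpr = auroc_lowest_quantile(scr, eff)
print(auc, cut, n_pos)
```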
In [None]:
#==========================================================================
# SECTION 16: MANUSCRIPT TEXT GENERATION
# Cell 20: Generate Abstract and Results Section Text (ROBUST + CONSISTENT)
#
# Fixes in this version:
#  - Handles gender coded as 'M'/'F', 1/0, True/False, or already "male_sex"
#  - Avoids KeyError if any results dict keys are missing
#  - Uses your computed medians/IQRs by risk from eff_24h_by_risk
#  - Removes the “strongest Spearman in any study” claim (hard to defend without full review)
#  - Keeps your AUROC Q20 CI consistent with your printed Table 2
#==========================================================================

import numpy as np
import pandas as pd

print("="*70)
print("MANUSCRIPT TEXT GENERATION (ROBUST)")
print("="*70)

# -----------------------------
# Helper functions
# -----------------------------
def safe_get(d, key, default=np.nan):
    return d.get(key, default) if isinstance(d, dict) else default

def safe_ci(d, key, default=(np.nan, np.nan)):
    ci = safe_get(d, key, default)
    if ci is None or (isinstance(ci, float) and np.isnan(ci)):
        return default
    if isinstance(ci, (list, tuple, np.ndarray)) and len(ci) == 2:
        return (ci[0], ci[1])
    return default

def fmt_pct(x, decimals=1):
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return "NA"
    return f"{x*100:.{decimals}f}%"

def fmt_mean_sd(series, decimals=1):
    s = pd.Series(series).dropna()
    if len(s) == 0:
        return "NA"
    return f"{s.mean():.{decimals}f} +/- {s.std():.{decimals}f}"

def infer_male_pct(df):
    """
    Returns (male_pct_float, male_label_str) where male_pct_float is 0-1 or np.nan
    """
    # Prefer explicit male_sex if present
    if 'male_sex' in df.columns:
        s = pd.Series(df['male_sex'])
        # map common encodings
        if s.dtype == 'object':
            s2 = s.map({'Yes': 1, 'No': 0, 'M': 1, 'F': 0, 'Male': 1, 'Female': 0, True: 1, False: 0})
        else:
            s2 = s.astype(float)
        s2 = s2.dropna()
        if len(s2) == 0:
            return (np.nan, "NA")
        return (float(s2.mean()), fmt_pct(float(s2.mean())))
    # Fall back to gender if present
    if 'gender' in df.columns:
        g = pd.Series(df['gender'])
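        if g.dtype == 'object':
            # Normalize text codes; values other than M/Male/F/Female (including missing) are dropped
            codes = g.dropna().astype(str).str.strip().str.upper()
            male = codes.map({'M': 1, 'MALE': 1, 'F': 0, 'FEMALE': 0}).dropna()
            if len(male) == 0:
                return (np.nan, "NA")
            male_pct = float(male.mean())
            return (male_pct, fmt_pct(male_pct))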
        # If numeric/bool: assume 1=male, 0=female (common)
        g2 = pd.to_numeric(g, errors='coerce').dropna()
        if len(g2) == 0:
            return (np.nan, "NA")
        male_pct = float((g2 == 1).mean())
        return (male_pct, fmt_pct(male_pct))
    return (np.nan, "NA")

def risk_counts(df):
    out = {}
    n_total = len(df)
    for cat in ['Low', 'Moderate', 'High']:
        n = int((df['risk_category'] == cat).sum()) if 'risk_category' in df.columns else 0
        out[cat] = {'n': n, 'pct': (n / n_total) if n_total > 0 else np.nan}
    return out

def fmt_ci(ci, decimals=3, sep=" to "):
    lo, hi = ci
    if lo is None or hi is None or (isinstance(lo, float) and np.isnan(lo)) or (isinstance(hi, float) and np.isnan(hi)):
        return "NA"
    return f"{lo:.{decimals}f}{sep}{hi:.{decimals}f}"

# -----------------------------
# Pull core stats safely
# -----------------------------
n_total = len(df)
age_mean_sd = fmt_mean_sd(df['age']) if 'age' in df.columns else "NA"
male_pct_float, male_pct_str = infer_male_pct(df)

ban_median = df['ban_adhf_total_score'].median() if 'ban_adhf_total_score' in df.columns else np.nan
ban_q1 = df['ban_adhf_total_score'].quantile(0.25) if 'ban_adhf_total_score' in df.columns else np.nan
ban_q3 = df['ban_adhf_total_score'].quantile(0.75) if 'ban_adhf_total_score' in df.columns else np.nan

risk = risk_counts(df)

# Results dict keys
n_24 = int(safe_get(results_24h, 'n', np.nan)) if safe_get(results_24h, 'n', None) is not None else np.nan
rho_24 = safe_get(results_24h, 'spearman_rho', np.nan)
rho_24_ci = safe_ci(results_24h, 'spearman_ci')
pearson_24 = safe_get(results_24h, 'pearson_r', np.nan)
pearson_24_ci = safe_ci(results_24h, 'pearson_ci')
auroc_q20 = safe_get(results_24h, 'auroc_quintile', np.nan)
auroc_q20_ci = safe_ci(results_24h, 'auroc_quintile_ci')
auroc_q25 = safe_get(results_24h, 'auroc_quartile', np.nan)
auroc_q25_ci = safe_ci(results_24h, 'auroc_quartile_ci')
c_index = safe_get(results_24h, 'c_index', np.nan)
c_index_ci = safe_ci(results_24h, 'c_index_ci')

# 72h
n_72 = int(safe_get(results_72h, 'n', np.nan)) if safe_get(results_72h, 'n', None) is not None else np.nan
rho_72 = safe_get(results_72h, 'spearman_rho', np.nan)
rho_72_ci = safe_ci(results_72h, 'spearman_ci')
auroc72_q20 = safe_get(results_72h, 'auroc_quintile', np.nan)
auroc72_q20_ci = safe_ci(results_72h, 'auroc_quintile_ci')

# DR + mortality
dr_prev = safe_get(results_dr, 'prevalence', np.nan)
dr_auroc = safe_get(results_dr, 'auroc', np.nan)
dr_auroc_ci = safe_ci(results_dr, 'auroc_ci')

mort_rate = safe_get(results_mortality, 'mortality_rate', np.nan)
mort_n = safe_get(results_mortality, 'n', np.nan)
mort_deaths = safe_get(results_mortality, 'n_deaths', np.nan)
mort_auroc = safe_get(results_mortality, 'auroc', np.nan)
mort_auroc_ci = safe_ci(results_mortality, 'auroc_ci')
mort_auroc_cs = safe_get(results_mortality, 'auroc_cs', np.nan)

# Fold difference (guard against divide by zero)
low_med = eff_24h_by_risk['Low']['median'] if 'Low' in eff_24h_by_risk else np.nan
high_med = eff_24h_by_risk['High']['median'] if 'High' in eff_24h_by_risk else np.nan
fold_diff = (low_med / high_med) if (pd.notna(low_med) and pd.notna(high_med) and high_med != 0) else np.nan
pct_reduction = ((low_med - high_med) / low_med * 100) if (pd.notna(low_med) and pd.notna(high_med) and low_med != 0) else np.nan

# Some cohort severity fields (only if columns exist)
cs_pct = df['cardiogenic_shock'].mean() if 'cardiogenic_shock' in df.columns else np.nan
vent_pct = df['invasive_vent'].mean() if 'invasive_vent' in df.columns else np.nan
ckd_pct = df['hx_renal_disease'].mean() if 'hx_renal_disease' in df.columns else np.nan
nt_median = df['ntprobnp'].median() if 'ntprobnp' in df.columns else np.nan
nt_q1 = df['ntprobnp'].quantile(0.25) if 'ntprobnp' in df.columns else np.nan
nt_q3 = df['ntprobnp'].quantile(0.75) if 'ntprobnp' in df.columns else np.nan

# Subgroup interactions (use what you computed, but don’t hardcode values)
interaction_results = safe_get(results_subgroups, 'interaction_results', {}) or {}

# HF phenotypes
hf_results = safe_get(results_subgroups, 'hf_results', {}) or {}

#==========================================================================
# ABSTRACT
#==========================================================================

print("\n" + "="*70)
print("ABSTRACT")
print("="*70)

abstract = f"""
BACKGROUND: The BAN-ADHF score was developed to predict diuretic efficiency in acute decompensated heart failure (ADHF). Its performance in critically ill intensive care unit (ICU) populations is less well characterized.

OBJECTIVES: To externally validate the BAN-ADHF score for predicting diuretic efficiency in ICU patients with ADHF.

METHODS: Retrospective cohort study using the MIMIC-IV database (2008-2022). We included adult ICU patients with ADHF receiving intravenous diuretics; analyses of 24-hour efficiency were restricted to patients with an ICU stay of at least 24 hours. The primary outcome was 24-hour diuretic efficiency (mL urine output per mg intravenous furosemide equivalent). Discrimination was assessed using Spearman correlation, Pearson correlation, concordance index (C-index), and area under the receiver operating characteristic curve (AUROC) for the lowest efficiency quintile. Patients were stratified into low (<=7), moderate (8-12), and high (>=13) risk categories.

RESULTS: Of {n_total:,} included patients (mean age {age_mean_sd} years; {male_pct_str} male), {n_24:,} had 24-hour efficiency data. In this group, the BAN-ADHF score was inversely associated with diuretic efficiency (Spearman rho = {rho_24:.3f}, 95% CI {rho_24_ci[0]:.3f} to {rho_24_ci[1]:.3f}; Pearson r = {pearson_24:.3f}, 95% CI {pearson_24_ci[0]:.3f} to {pearson_24_ci[1]:.3f}). The AUROC for identifying the lowest efficiency quintile was {auroc_q20:.3f} (95% CI {auroc_q20_ci[0]:.3f}-{auroc_q20_ci[1]:.3f}). Median 24-hour diuretic efficiency decreased across risk categories: low {eff_24h_by_risk['Low']['median']:.1f} mL/mg [IQR {eff_24h_by_risk['Low']['q1']:.1f}-{eff_24h_by_risk['Low']['q3']:.1f}], moderate {eff_24h_by_risk['Moderate']['median']:.1f} mL/mg [{eff_24h_by_risk['Moderate']['q1']:.1f}-{eff_24h_by_risk['Moderate']['q3']:.1f}], and high {eff_24h_by_risk['High']['median']:.1f} mL/mg [{eff_24h_by_risk['High']['q1']:.1f}-{eff_24h_by_risk['High']['q3']:.1f}] (p < 0.001). Findings were largely consistent across pre-specified sensitivity analyses and subgroups.

CONCLUSIONS: The BAN-ADHF score demonstrates preserved discrimination for diuretic efficiency in critically ill ICU patients, supporting its use for risk stratification in this high-acuity population.
""".strip()

print(abstract)

#==========================================================================
# RESULTS SECTION
#==========================================================================

print("\n" + "="*70)
print("RESULTS SECTION")
print("="*70)

# Build interaction summary lines dynamically (only significant ones)
sig_interactions = []
for name, d in interaction_results.items():
    p_int = safe_get(d, 'p_interaction', np.nan)
    if pd.notna(p_int) and p_int < 0.05:
        sig_interactions.append((name, d))

interaction_block = ""
if len(sig_interactions) > 0:
    interaction_lines = []
    for name, d in sig_interactions:
        p_int = safe_get(d, 'p_interaction', np.nan)
        rho1 = safe_get(d, 'rho_1', np.nan)
        rho2 = safe_get(d, 'rho_2', np.nan)
        interaction_lines.append(
            f"{name}: p-interaction = {p_int:.3f}; rho = {rho1:.3f} vs {rho2:.3f}"
        )
    interaction_block = "Significant effect modification was observed for: " + "; ".join(interaction_lines) + "."
else:
    interaction_block = "No statistically significant effect modification was observed across pre-specified subgroups."

# HF phenotype block (falls back to a placeholder sentence if no phenotype results exist)
hf_block = "Heart failure phenotype results were not available."
if isinstance(hf_results, dict) and len(hf_results) > 0:
    parts = []
    for pheno in ['HFrEF', 'HFmrEF', 'HFpEF']:
        if pheno in hf_results:
            r = hf_results[pheno]
            parts.append(
                f"{pheno} (rho = {safe_get(r, 'spearman_rho', np.nan):.3f}, AUROC = {safe_get(r, 'auroc', np.nan):.3f})"
            )
    if len(parts) > 0:
        hf_block = "Score performance by heart failure phenotype was as follows (eTable 4): " + ", ".join(parts) + "."

results_text = f"""
RESULTS

Study Population

From the MIMIC-IV database, we identified {n_total:,} patients meeting inclusion criteria (Figure 1). Age was {age_mean_sd} years (mean +/- SD), and {male_pct_str} were male. The median BAN-ADHF score was {ban_median:.0f} (IQR {ban_q1:.0f}-{ban_q3:.0f}). Patients were distributed across risk categories as follows: low risk (score <=7) {risk['Low']['n']:,} ({risk['Low']['pct']*100:.1f}%), moderate risk (8-12) {risk['Moderate']['n']:,} ({risk['Moderate']['pct']*100:.1f}%), and high risk (>=13) {risk['High']['n']:,} ({risk['High']['pct']*100:.1f}%) (Table 1).

In this ICU cohort, {fmt_pct(cs_pct)} presented with cardiogenic shock and {fmt_pct(vent_pct)} required invasive mechanical ventilation. The prevalence of chronic kidney disease was {fmt_pct(ckd_pct)}. Among patients with an available measurement, median NT-proBNP was {nt_median:,.0f} pg/mL (IQR {nt_q1:,.0f}-{nt_q3:,.0f}).

Primary Outcome: 24-Hour Diuretic Efficiency

Among {n_24:,} patients with valid 24-hour diuretic efficiency data, the BAN-ADHF score demonstrated an inverse association with diuretic efficiency (Spearman rho = {rho_24:.3f}, 95% CI {fmt_ci(rho_24_ci, decimals=3)}, p < 0.001; Pearson r = {pearson_24:.3f}, 95% CI {fmt_ci(pearson_24_ci, decimals=3)}, p < 0.001) (Figure 2, eFigure 3). The C-index was {c_index:.3f} (95% CI {c_index_ci[0]:.3f}-{c_index_ci[1]:.3f}).

For binary discrimination, the AUROC for identifying patients in the lowest efficiency quintile was {auroc_q20:.3f} (95% CI {auroc_q20_ci[0]:.3f}-{auroc_q20_ci[1]:.3f}), and for the lowest quartile was {auroc_q25:.3f} (95% CI {auroc_q25_ci[0]:.3f}-{auroc_q25_ci[1]:.3f}) (Table 2).

Risk Stratification

Median 24-hour diuretic efficiency demonstrated a clear gradient across risk categories (Table 2, Figure 2A). Patients in the low-risk category achieved a median efficiency of {eff_24h_by_risk['Low']['median']:.1f} mL/mg (IQR {eff_24h_by_risk['Low']['q1']:.1f}-{eff_24h_by_risk['Low']['q3']:.1f}), compared with {eff_24h_by_risk['Moderate']['median']:.1f} mL/mg (IQR {eff_24h_by_risk['Moderate']['q1']:.1f}-{eff_24h_by_risk['Moderate']['q3']:.1f}) for moderate risk and {eff_24h_by_risk['High']['median']:.1f} mL/mg (IQR {eff_24h_by_risk['High']['q1']:.1f}-{eff_24h_by_risk['High']['q3']:.1f}) for high risk (p < 0.001). This corresponds to a {fold_diff:.1f}-fold difference in median efficiency from low to high risk, or a {pct_reduction:.0f}% reduction relative to low risk.

Secondary Outcomes

72-Hour Diuretic Efficiency: Among {n_72:,} patients with an ICU stay of at least 72 hours, the inverse association remained (Spearman rho = {rho_72:.3f}, 95% CI {fmt_ci(rho_72_ci, decimals=3)}, p < 0.001), with AUROC for the lowest efficiency quintile of {auroc72_q20:.3f} (95% CI {auroc72_q20_ci[0]:.3f}-{auroc72_q20_ci[1]:.3f}).

Diuretic Resistance: Diuretic resistance (24-hour urine output <=3000 mL) occurred in {dr_prev:.1f}% of patients. The BAN-ADHF score demonstrated modest discrimination (AUROC {dr_auroc:.3f}, 95% CI {dr_auroc_ci[0]:.3f}-{dr_auroc_ci[1]:.3f}).

In-Hospital Mortality: Overall mortality was {mort_rate:.1f}% ({mort_deaths} of {mort_n:,}). The BAN-ADHF score showed modest discrimination for mortality (AUROC {mort_auroc:.3f}, 95% CI {mort_auroc_ci[0]:.3f}-{mort_auroc_ci[1]:.3f}), with higher AUROC in the cardiogenic shock subgroup (AUROC {mort_auroc_cs:.3f}). Mortality increased across risk categories (Table 2).

Subgroup Analyses

The inverse association between BAN-ADHF score and 24-hour diuretic efficiency was directionally consistent across pre-specified subgroups (Figure 3). {interaction_block}

Sensitivity Analyses

Results were largely robust across sensitivity analyses (eTable 2). Exclusion of outliers (>99th percentile efficiency), exclusion of cardiogenic shock, and restriction to patients with complete left ventricular ejection fraction data yielded similar estimates. Exclusion of patients with advanced chronic kidney disease attenuated the association, which is expected: creatinine is a score component, so this exclusion narrows the range of scores, and renal function is independently associated with diuretic response.

HF Phenotype Analysis

{hf_block}

Comparison with Literature

Our findings align with prior validation studies (Table 3). In this ICU cohort, binary discrimination for diuretic resistance and mortality was modest, which may reflect outcome prevalence and, for diuretic resistance, the information lost by dichotomizing a continuous measure.
""".strip()

print(results_text)

#==========================================================================
# KEY STATISTICS SUMMARY
#==========================================================================

print("\n" + "="*70)
print("KEY STATISTICS FOR QUICK REFERENCE")
print("="*70)

print(f"""
COHORT:
  N total: {n_total:,}
  N 24h analysis: {n_24:,}
  N 72h analysis: {n_72:,}
  Age: {age_mean_sd} years
  Male: {male_pct_str}

SCORE DISTRIBUTION:
  Median [IQR]: {ban_median:.0f} [{ban_q1:.0f}-{ban_q3:.0f}]
  Low risk (<=7): {risk['Low']['n']:,} ({risk['Low']['pct']*100:.1f}%)
  Moderate (8-12): {risk['Moderate']['n']:,} ({risk['Moderate']['pct']*100:.1f}%)
  High risk (>=13): {risk['High']['n']:,} ({risk['High']['pct']*100:.1f}%)

PRIMARY OUTCOME (24h Efficiency):
  Spearman rho: {rho_24:.3f} (95% CI {rho_24_ci[0]:.3f}, {rho_24_ci[1]:.3f})
  Pearson r: {pearson_24:.3f} (95% CI {pearson_24_ci[0]:.3f}, {pearson_24_ci[1]:.3f})
  C-index: {c_index:.3f} (95% CI {c_index_ci[0]:.3f}, {c_index_ci[1]:.3f})
  AUROC Q20: {auroc_q20:.3f} (95% CI {auroc_q20_ci[0]:.3f}, {auroc_q20_ci[1]:.3f})
  AUROC Q25: {auroc_q25:.3f} (95% CI {auroc_q25_ci[0]:.3f}, {auroc_q25_ci[1]:.3f})

EFFICIENCY BY RISK (24h):
  Low: {eff_24h_by_risk['Low']['median']:.1f} mL/mg [{eff_24h_by_risk['Low']['q1']:.1f}-{eff_24h_by_risk['Low']['q3']:.1f}]
  Moderate: {eff_24h_by_risk['Moderate']['median']:.1f} mL/mg [{eff_24h_by_risk['Moderate']['q1']:.1f}-{eff_24h_by_risk['Moderate']['q3']:.1f}]
  High: {eff_24h_by_risk['High']['median']:.1f} mL/mg [{eff_24h_by_risk['High']['q1']:.1f}-{eff_24h_by_risk['High']['q3']:.1f}]
  Fold difference (Low vs High): {fold_diff:.1f}-fold

SECONDARY OUTCOMES:
  72h rho: {rho_72:.3f} (95% CI {rho_72_ci[0]:.3f}, {rho_72_ci[1]:.3f})
  DR prevalence: {dr_prev:.1f}%, AUROC {dr_auroc:.3f}
  Mortality: {mort_rate:.1f}%, AUROC {mort_auroc:.3f}
""")

#==========================================================================
# SAVE TEXT FILES
#==========================================================================

with open('/content/manuscript_abstract.txt', 'w') as f:
    f.write(abstract + "\n")

with open('/content/manuscript_results.txt', 'w') as f:
    f.write(results_text + "\n")

print("\n" + "="*70)
print("MANUSCRIPT TEXT COMPLETE")
print("="*70)
print("""
Files generated:
  - manuscript_abstract.txt
  - manuscript_results.txt

Notes:
  - Gender percent is computed robustly (male_sex if available, otherwise gender).
  - Claims that require comprehensive literature review were removed to keep the text defensible.
""")
print("\n-> Next: Package all outputs (Cell 21)")

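The fold difference and percent reduction in the Results text are two simple ratios of the low-risk and high-risk medians. A short worked check using the medians reported in the study overview (47.4 and 11.3 mL/mg) is below; `fold_and_reduction` is a hypothetical helper that mirrors the divide-by-zero guards used in Cell 20.

```python
import math


def fold_and_reduction(low_median, high_median):
    """Return (fold difference, % reduction) for low- vs high-risk medians.

    Like the Cell 20 guards, a zero denominator yields NaN instead of an error.
    """
    fold = low_median / high_median if high_median else math.nan
    reduction = (low_median - high_median) / low_median * 100 if low_median else math.nan
    return fold, reduction


fold, reduction = fold_and_reduction(47.4, 11.3)
print(f"{fold:.1f}-fold difference, {reduction:.0f}% reduction")  # 4.2-fold difference, 76% reduction
```

This reproduces the 4.2-fold difference stated in the study overview (47.4 / 11.3 is about 4.19, and 36.1 / 47.4 is about 76%).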

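The C-index reported for the continuous 24-hour efficiency outcome can be read as a pairwise concordance probability. The sketch below is one possible definition, assuming that a higher BAN-ADHF score should pair with lower efficiency, that pairs with tied efficiency are excluded, and that ties in score count as half-concordant. `c_index_inverse` is a hypothetical helper, and the notebook's `results_24h['c_index']` may follow a different convention. It builds full pairwise difference matrices, which is fine for about 1,000 patients but uses O(n^2) memory.

```python
import numpy as np


def c_index_inverse(score, outcome):
    """Pairwise concordance for an expected inverse score-outcome relationship.

    Hypothetical helper. A pair is usable when the two outcomes differ; it is
    concordant when the patient with the higher score has the lower outcome,
    and pairs with tied scores count as 0.5.
    """
    s = np.asarray(score, dtype=float)
    y = np.asarray(outcome, dtype=float)
    keep = ~(np.isnan(s) | np.isnan(y))    # NaN would otherwise look like a usable pair
    s, y = s[keep], y[keep]

    ds = s[:, None] - s[None, :]           # pairwise score differences (n x n)
    dy = y[:, None] - y[None, :]           # pairwise outcome differences (n x n)
    usable = dy != 0
    concordant = usable & (ds != 0) & (np.sign(ds) == -np.sign(dy))
    tied_score = usable & (ds == 0)
    # Every pair appears as (i, j) and (j, i); the double count cancels in the ratio
    return float((concordant.sum() + 0.5 * tied_score.sum()) / usable.sum())


eff = np.arange(1, 101, dtype=float)
print(c_index_inverse(100 - eff, eff))       # 1.0 (perfect inverse ranking)
print(c_index_inverse(eff, eff))             # 0.0 (ranking in the wrong direction)
print(c_index_inverse(np.zeros(100), eff))   # 0.5 (uninformative score)
```

Under this convention 0.5 means no discrimination, so the value is directly comparable with the AUROCs reported elsewhere in the notebook.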
In [None]:
#==========================================================================
# SECTION 17: PACKAGE ALL OUTPUTS
# Cell 21: Organize files, create checklists, README, and ZIP (ROBUST)
#
# Fixes in this version:
#  - Robust file copy: tries multiple possible filenames (handles your Figure/R export name drift)
#  - No KeyError in README: uses safe getters for dicts + safe CI formatting
#  - Risk table in README uses eff_24h_by_risk if present, otherwise computes on the fly
#  - Fold difference guards against divide-by-zero
#  - "Comparison with Literature" deltas won't crash if values missing. Also avoids hardcoded % deltas
#  - Uses a stable Analysis Date (yyyy-mm-dd)
#==========================================================================

import os
import shutil
import pandas as pd
from datetime import datetime
import numpy as np
import pickle

print("="*70)
print("PACKAGING ALL OUTPUTS (ROBUST)")
print("="*70)

#==========================================================================
# HELPERS
#==========================================================================

def safe_get(d, key, default=np.nan):
    return d.get(key, default) if isinstance(d, dict) else default

def safe_ci(d, key, default=(np.nan, np.nan)):
    ci = safe_get(d, key, default)
    if ci is None:
        return default
    if isinstance(ci, (list, tuple)) and len(ci) == 2:
        return (ci[0], ci[1])
    return default

def fmt_num(x, nd=3):
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return "NA"
    return f"{x:.{nd}f}"

def fmt_ci(ci, nd=3):
    lo, hi = ci
    if lo is None or hi is None:
        return "NA"
    if (isinstance(lo, float) and np.isnan(lo)) or (isinstance(hi, float) and np.isnan(hi)):
        return "NA"
    return f"({lo:.{nd}f}, {hi:.{nd}f})"

def try_copy(possible_names, dest_dir, label_prefix=""):
    """
    possible_names: list of filenames (without /content/)
    Copies the first existing one found.
    Returns (copied_bool, chosen_filename_or_None)
    """
    for name in possible_names:
        src = f"/content/{name}"
        if os.path.exists(src):
            shutil.copy(src, os.path.join(dest_dir, name))
            if label_prefix:
                print(f"  ✓ {label_prefix}/{name}")
            else:
                print(f"  ✓ {name}")
            return True, name
    # nothing found
    missing_display = possible_names[0] if possible_names else "UNKNOWN"
    print(f"  ⚠ Missing: {missing_display}")
    return False, None

def ensure_dir(path):
    os.makedirs(path, exist_ok=True)

#==========================================================================
# CREATE FOLDER STRUCTURE
#==========================================================================

base_dir = '/content/BAN_ADHF_Validation'

# Remove existing folder if present
if os.path.exists(base_dir):
    shutil.rmtree(base_dir)

folders = [
    f'{base_dir}/Tables/Main',
    f'{base_dir}/Tables/Supplementary',
    f'{base_dir}/Figures/Main',
    f'{base_dir}/Figures/Supplementary',
    f'{base_dir}/Figures/R_Data',
    f'{base_dir}/Manuscript',
    f'{base_dir}/Checklists',
    f'{base_dir}/Data'
]

for folder in folders:
    ensure_dir(folder)

print("✓ Folder structure created")

#==========================================================================
# COPY FILES TO APPROPRIATE FOLDERS
#==========================================================================

print("\n" + "-"*70)
print("Copying files to organized folders...")
print("-"*70)

# --- Main Tables (stable names)
main_tables = [
    'Table1_baseline_by_risk.csv',
    'Table1_baseline_by_risk.xlsx',
    'Table2_outcomes_discrimination.csv',
    'Table2_outcomes_discrimination.xlsx',
    'Table3_literature_comparison.csv',
    'Table3_literature_comparison.xlsx'
]
for f in main_tables:
    try_copy([f], f'{base_dir}/Tables/Main', label_prefix="Tables/Main")

# --- Supplementary Tables
supp_tables = [
    'eTable1_baseline_by_mortality.csv',
    'eTable1_baseline_by_mortality.xlsx',
    'eTable2_sensitivity_analyses.csv',
    'eTable2_sensitivity_analyses.xlsx',
    'eTable3_secondary_outcomes.csv',
    'eTable3_secondary_outcomes.xlsx',
    'eTable4_hf_phenotype.csv',
    'eTable4_hf_phenotype.xlsx'
]
for f in supp_tables:
    try_copy([f], f'{base_dir}/Tables/Supplementary', label_prefix="Tables/Supplementary")

# --- Main Figures (handle .tif vs .tiff drift)
main_figures_map = [
    (['Figure2_efficiency_by_risk.png'], 'Figures/Main'),
    (['Figure2_efficiency_by_risk.tiff', 'Figure2_efficiency_by_risk.tif'], 'Figures/Main'),
    (['Figure3_forest_plot.png'], 'Figures/Main'),
    (['Figure3_forest_plot.tiff', 'Figure3_forest_plot.tif'], 'Figures/Main'),
]
for names, subdir in main_figures_map:
    try_copy(names, f'{base_dir}/{subdir}', label_prefix=subdir)

# --- Supplementary Figures
supp_figures_map = [
    (['eFigure1_ROC_curves.png'], 'Figures/Supplementary'),
    (['eFigure1_ROC_curves.tiff', 'eFigure1_ROC_curves.tif'], 'Figures/Supplementary'),
    (['eFigure2_score_distribution.png'], 'Figures/Supplementary'),
    (['eFigure2_score_distribution.tiff', 'eFigure2_score_distribution.tif'], 'Figures/Supplementary'),
    (['eFigure3_scatter_correlation.png'], 'Figures/Supplementary'),
    (['eFigure3_scatter_correlation.tiff', 'eFigure3_scatter_correlation.tif'], 'Figures/Supplementary'),
]
for names, subdir in supp_figures_map:
    try_copy(names, f'{base_dir}/{subdir}', label_prefix=subdir)

# --- R Data files (your earlier section exported these names, but older notebooks sometimes used others)
r_data_candidates = {
    # current names from your Section 14/15
    'data_for_R_figure2_efficiency.csv': ['data_for_R_figure2_efficiency.csv'],
    'data_for_R_figure2_summary.csv': ['data_for_R_figure2_summary.csv'],
    'data_for_R_figure3_forest.csv': ['data_for_R_figure3_forest.csv'],
    'data_for_R_interactions.csv': ['data_for_R_interactions.csv'],
    'data_for_R_subgroups_detail.csv': ['data_for_R_subgroups_detail.csv'],
    'data_for_R_literature_comparison.csv': ['data_for_R_literature_comparison.csv'],
    'data_for_R_roc_q20.csv': ['data_for_R_roc_q20.csv'],
    'data_for_R_score_distribution.csv': ['data_for_R_score_distribution.csv'],
    'data_for_R_scatter.csv': ['data_for_R_scatter.csv'],

    # legacy / alternate names (if they exist, we still bring them in)
    'data_for_R_forest_plot.csv': ['data_for_R_forest_plot.csv', 'data_for_R_figure3_forest.csv'],
    'data_for_R_outcomes_by_risk.csv': ['data_for_R_outcomes_by_risk.csv'],
    'data_for_R_efficiency_raw.csv': ['data_for_R_efficiency_raw.csv', 'data_for_R_figure2_efficiency.csv'],
}

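copied_r = set()
for canonical, candidates in r_data_candidates.items():
    # A legacy alias may resolve to a file already copied under its current name;
    # skip it so the same file is not copied and logged twice
    existing = [c for c in candidates if os.path.exists(f"/content/{c}")]
    if existing and existing[0] in copied_r:
        continue
    copied, chosen = try_copy(candidates, f'{base_dir}/Figures/R_Data', label_prefix="Figures/R_Data")
    if copied and chosen:
        copied_r.add(chosen)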

# --- Manuscript text files
manuscript_files = ['manuscript_abstract.txt', 'manuscript_results.txt']
for f in manuscript_files:
    try_copy([f], f'{base_dir}/Manuscript', label_prefix="Manuscript")

#==========================================================================
# CREATE TRIPOD CHECKLIST (Prediction Model Validation)
#==========================================================================

print("\n" + "-"*70)
print("Creating TRIPOD Checklist...")
print("-"*70)

# Item numbers and wording follow the TRIPOD 2015 checklist; development-only items are marked N/A
tripod_items = [
    ('1', 'Title', 'Identify the study as developing and/or validating a multivariable prediction model', 'Yes', 'Title includes external validation language'),
    ('2', 'Abstract', 'Provide a summary of objectives, design, setting, participants, sample size, predictors, outcome, analysis, results, and conclusions', 'Yes', 'Structured abstract provided'),
    ('3a', 'Background', 'Explain medical context and rationale', 'Yes', 'Introduction'),
    ('3b', 'Objectives', 'Specify objectives, including development vs validation', 'Yes', 'External validation in ICU'),
    ('4a', 'Source of data', 'Describe study design or data source', 'Yes', 'MIMIC-IV retrospective cohort'),
    ('4b', 'Source of data', 'Specify key study dates', 'Yes', '2008-2022'),
    ('5a', 'Participants', 'Specify key elements of the setting', 'Yes', 'ICU ADHF cohort'),
    ('5b', 'Participants', 'Describe eligibility criteria', 'Yes', 'Methods'),
    ('5c', 'Participants', 'Give details of treatments received, if relevant', 'Yes', 'IV diuretics'),
    ('6a', 'Outcome', 'Clearly define the outcome', 'Yes', '24h diuretic efficiency (mL/mg)'),
    ('6b', 'Outcome', 'Report blinding of outcome assessment', 'N/A', 'Retrospective database'),
    ('7a', 'Predictors', 'Clearly define predictors', 'Yes', 'BAN-ADHF components'),
    ('7b', 'Predictors', 'Report blinding of predictors assessment', 'N/A', 'Retrospective database'),
    ('8', 'Sample size', 'Explain how study size was arrived at', 'Yes', 'All eligible patients'),
    ('9', 'Missing data', 'Describe how missing data were handled', 'Yes', 'Complete-case + sensitivity analyses'),
    ('10a', 'Statistical analysis', 'Describe how predictors were handled', 'Yes', 'Score per original algorithm'),
    ('10b', 'Statistical analysis', 'Specify model-building procedures', 'Yes', 'External validation only'),
    ('10c', 'Statistical analysis', 'For validation, describe how predictions were calculated', 'Yes', 'Score applied without updating'),
    ('10d', 'Statistical analysis', 'Specify performance measures', 'Yes', 'rho, r, C-index, AUROC'),
    ('10e', 'Statistical analysis', 'Describe any model updating', 'N/A', 'No updating'),
    ('11', 'Risk groups', 'Provide details on risk groups', 'Yes', 'Low/Moderate/High thresholds'),
    ('12', 'Development vs validation', 'For validation, identify any differences from the development data in setting, eligibility criteria, outcome, and predictors', 'Yes', 'ICU population; derivation cohorts excluded hemodynamically unstable patients'),
    ('13a', 'Participants', 'Describe participant flow', 'Yes', 'Figure 1'),
    ('13b', 'Participants', 'Describe participant characteristics', 'Yes', 'Table 1'),
    ('13c', 'Participants', 'Compare with development data', 'Yes', 'Table 3'),
    ('14a', 'Model development', 'Specify number of participants and outcome events in each analysis', 'Yes', 'Reported per outcome'),
    ('14b', 'Model development', 'Report unadjusted association between each candidate predictor and outcome', 'Yes', 'Correlations'),
    ('15a', 'Model specification', 'Present the full prediction model', 'N/A', 'Development item; published score applied'),
    ('15b', 'Model specification', 'Explain how to use the prediction model', 'N/A', 'Development item'),
    ('16', 'Model performance', 'Report performance measures with CIs', 'Yes', 'CIs provided'),
    ('17', 'Model updating', 'Report results from any model updating', 'N/A', 'No updating'),
    ('18', 'Limitations', 'Discuss limitations', 'Yes', 'Discussion'),
    ('19a', 'Interpretation', 'For validation, discuss results with reference to performance in the development data and other validation data', 'Yes', 'Comparison table'),
    ('19b', 'Interpretation', 'Give an overall interpretation of the results', 'Yes', 'Discussion'),
    ('20', 'Implications', 'Discuss potential clinical use and implications for future research', 'Yes', 'Discussion; future directions'),
    ('21', 'Supplementary information', 'Availability of supplementary resources', 'Yes', 'eTables/eFigures'),
    ('22', 'Funding', 'Source of funding', 'TBD', 'To be added')
]

tripod_df = pd.DataFrame(tripod_items, columns=['Item', 'Section', 'Checklist Item', 'Reported', 'Location/Comment'])
tripod_df.to_csv(f'{base_dir}/Checklists/TRIPOD_checklist.csv', index=False)
tripod_df.to_excel(f'{base_dir}/Checklists/TRIPOD_checklist.xlsx', index=False)
print("✓ TRIPOD checklist created (CSV and XLSX)")

#==========================================================================
# CREATE STROBE CHECKLIST (Observational Study)
#==========================================================================

print("\n" + "-"*70)
print("Creating STROBE Checklist...")
print("-"*70)

strobe_items = [
    ('1a', 'Title and abstract', 'Indicate the study design with a commonly used term', 'Yes', 'Retrospective cohort study'),
    ('1b', 'Title and abstract', 'Provide an informative and balanced summary', 'Yes', 'Structured abstract'),
    ('2', 'Background/rationale', 'Explain scientific background and rationale', 'Yes', 'Introduction'),
    ('3', 'Objectives', 'State specific objectives', 'Yes', 'Primary objective'),
    ('4', 'Study design', 'Present key elements early in the paper', 'Yes', 'Methods'),
    ('5', 'Setting', 'Describe setting, locations, relevant dates', 'Yes', 'MIMIC-IV v3.1, 2008-2022'),
    ('6a', 'Participants', 'Eligibility criteria and selection methods', 'Yes', 'Methods'),
    ('7', 'Variables', 'Define outcomes, exposures, predictors, confounders', 'Yes', 'Methods'),
    ('8', 'Data sources', 'Sources and assessment methods', 'Yes', 'Methods'),
    ('9', 'Bias', 'Efforts to address bias', 'Yes', 'Sensitivity analyses'),
    ('10', 'Study size', 'How study size arrived at', 'Yes', 'All eligible'),
    ('11', 'Quantitative variables', 'How quantitative variables handled', 'Yes', 'Methods'),
    ('12a', 'Statistical methods', 'All statistical methods', 'Yes', 'Methods'),
    ('12b', 'Statistical methods', 'Methods for subgroups/interactions', 'Yes', 'Subgroup analyses'),
    ('12c', 'Statistical methods', 'Missing data handling', 'Yes', 'Complete-case + sensitivity'),
    ('12d', 'Statistical methods', 'Loss to follow-up', 'N/A', 'In-hospital outcomes'),
    ('12e', 'Statistical methods', 'Sensitivity analyses', 'Yes', 'eTable 2'),
    ('13a', 'Participants', 'Numbers at each stage', 'Yes', 'Figure 1'),
    ('13b', 'Participants', 'Reasons for non-participation', 'Yes', 'Figure 1'),
    ('13c', 'Participants', 'Flow diagram', 'Yes', 'Figure 1'),
    ('14a', 'Descriptive data', 'Characteristics of participants', 'Yes', 'Table 1'),
    ('14b', 'Descriptive data', 'Missing data counts', 'Yes', 'Tables/eTables'),
    ('15', 'Outcome data', 'Outcome events or summary measures', 'Yes', 'Table 2'),
    ('16a', 'Main results', 'Unadjusted estimates and precision', 'Yes', 'Correlations + CIs'),
    ('16b', 'Main results', 'Category boundaries', 'Yes', 'Risk categories'),
    ('16c', 'Main results', 'Meaningful time periods', 'N/A', 'Prediction at 24h/72h'),
    ('17', 'Other analyses', 'Subgroups/interactions/sensitivity', 'Yes', 'Results'),
    ('18', 'Key results', 'Summarize key results', 'Yes', 'Discussion'),
    ('19', 'Limitations', 'Limitations and bias', 'Yes', 'Discussion'),
    ('20', 'Interpretation', 'Cautious interpretation', 'Yes', 'Discussion'),
    ('21', 'Generalisability', 'Generalisability', 'Yes', 'Discussion'),
    ('22', 'Funding', 'Funding source and role', 'TBD', 'To be added')
]

strobe_df = pd.DataFrame(strobe_items, columns=['Item', 'Section', 'Checklist Item', 'Reported', 'Location/Comment'])
strobe_df.to_csv(f'{base_dir}/Checklists/STROBE_checklist.csv', index=False)
strobe_df.to_excel(f'{base_dir}/Checklists/STROBE_checklist.xlsx', index=False)
print("✓ STROBE checklist created (CSV and XLSX)")

#==========================================================================
# CREATE README FILE (NO KEYERRORS)
#==========================================================================

print("\n" + "-"*70)
print("Creating README file...")
print("-"*70)

analysis_date = datetime.now().strftime('%Y-%m-%d')

# Pull key metrics safely for README
_n24 = safe_get(results_24h, 'n', None)
n24 = int(_n24) if _n24 is not None and pd.notna(_n24) else 0  # int(NaN) would raise
rho24 = safe_get(results_24h, 'spearman_rho', np.nan)
rho24_ci = safe_ci(results_24h, 'spearman_ci')
r24 = safe_get(results_24h, 'pearson_r', np.nan)
r24_ci = safe_ci(results_24h, 'pearson_ci')
cidx = safe_get(results_24h, 'c_index', np.nan)
cidx_ci = safe_ci(results_24h, 'c_index_ci')
auc20 = safe_get(results_24h, 'auroc_quintile', np.nan)
auc20_ci = safe_ci(results_24h, 'auroc_quintile_ci')
auc25 = safe_get(results_24h, 'auroc_quartile', np.nan)
auc25_ci = safe_ci(results_24h, 'auroc_quartile_ci')

# Risk stats (use eff_24h_by_risk if available)
def risk_row(cat, score_label):
    if isinstance(eff_24h_by_risk, dict) and cat in eff_24h_by_risk:
        d = eff_24h_by_risk[cat]
        return (cat, score_label, int(d.get('n', 0)), f"{d.get('median', np.nan):.1f} [{d.get('q1', np.nan):.1f}-{d.get('q3', np.nan):.1f}]")
    # fallback: compute
    if 'risk_category' in df.columns and 'diuretic_efficiency_24h' in df.columns:
        sub = df[(df['risk_category']==cat) & df['diuretic_efficiency_24h'].notna() & (df['diuretic_efficiency_24h']>0)]['diuretic_efficiency_24h']
        if len(sub)==0:
            return (cat, score_label, 0, "NA")
        return (cat, score_label, len(sub), f"{sub.median():.1f} [{sub.quantile(0.25):.1f}-{sub.quantile(0.75):.1f}]")
    return (cat, score_label, 0, "NA")

low_row = risk_row('Low', '≤7')
mod_row = risk_row('Moderate', '8-12')
high_row = risk_row('High', '≥13')
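The fallback branch of `risk_row` summarizes each risk category as median [IQR] after dropping missing and non-positive efficiencies. The same summary can be sketched with NumPy alone. `median_iqr_str` is a hypothetical helper; its linear-interpolation percentiles match the default behavior of pandas `Series.quantile`:

```python
import numpy as np

def median_iqr_str(values, digits=1):
    """Return 'median [Q1-Q3]' for finite, positive values, or 'NA' if none remain."""
    arr = np.asarray(values, dtype=float)
    # Keep finite, positive efficiencies (like risk_row's notna & > 0 filter, also dropping inf)
    arr = arr[np.isfinite(arr) & (arr > 0)]
    if arr.size == 0:
        return "NA"
    q1, med, q3 = np.percentile(arr, [25, 50, 75])
    return f"{med:.{digits}f} [{q1:.{digits}f}-{q3:.{digits}f}]"
```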

# Fold difference in median 24h efficiency (Low vs High risk)
try:
    fold = (eff_24h_by_risk['Low']['median'] / eff_24h_by_risk['High']['median'])
    fold_str = f"{fold:.1f}-fold" if np.isfinite(fold) else "NA"
except Exception:
    fold_str = "NA"

# Secondary outcomes (safe)
n72 = safe_get(results_72h, 'n', np.nan)
rho72 = safe_get(results_72h, 'spearman_rho', np.nan)
auc72 = safe_get(results_72h, 'auroc_quintile', np.nan)
auc72_ci = safe_ci(results_72h, 'auroc_quintile_ci')

dr_n = safe_get(results_dr, 'n', np.nan)
dr_prev = safe_get(results_dr, 'prevalence', np.nan)
dr_auc = safe_get(results_dr, 'auroc', np.nan)
dr_auc_ci = safe_ci(results_dr, 'auroc_ci')

mort_n = safe_get(results_mortality, 'n', np.nan)
mort_rate = safe_get(results_mortality, 'mortality_rate', np.nan)
mort_auc = safe_get(results_mortality, 'auroc', np.nan)
mort_auc_ci = safe_ci(results_mortality, 'auroc_ci')

# Significant interactions from interaction_data if present, else from results_subgroups dict
sig_interactions_md = ""
if 'interaction_data' in globals() and isinstance(interaction_data, list) and len(interaction_data) > 0:
    sig = [x for x in interaction_data if x.get('significant', False)]
    if len(sig) > 0:
        rows = []
        for x in sig:
            rows.append(f"| {x.get('subgroup','NA')} | {x.get('p_interaction', np.nan):.4f} | ρ = {x.get('rho_group1', np.nan):.3f} vs {x.get('rho_group2', np.nan):.3f} |")
        sig_interactions_md = "\n".join(rows)
else:
    inter = safe_get(results_subgroups, 'interaction_results', {}) or {}
    sig = [(k,v) for k,v in inter.items() if pd.notna(safe_get(v,'p_interaction', np.nan)) and safe_get(v,'p_interaction', np.nan) < 0.05]
    if len(sig) > 0:
        rows = []
        for k,v in sig:
            rows.append(f"| {k} | {safe_get(v,'p_interaction', np.nan):.4f} | ρ = {safe_get(v,'rho_1', np.nan):.3f} vs {safe_get(v,'rho_2', np.nan):.3f} |")
        sig_interactions_md = "\n".join(rows)

if sig_interactions_md == "":
    sig_interactions_md = "| None detected | NA | NA |"

readme_content = f"""# BAN-ADHF Score External Validation in ICU Patients

## Study Information
- **Title**: External Validation of the BAN-ADHF Score for Predicting Diuretic Efficiency in Critically Ill Patients with Acute Decompensated Heart Failure
- **Database**: MIMIC-IV v3.1 (2008-2022)
- **Analysis Date**: {analysis_date}

---

## Key Findings

### Primary Outcome (24-Hour Diuretic Efficiency)
| Metric | Value | 95% CI |
|--------|-------|--------|
| N | {n24:,} | - |
| Spearman ρ | {fmt_num(rho24, 3)} | {fmt_ci(rho24_ci, 3)} |
| Pearson r | {fmt_num(r24, 3)} | {fmt_ci(r24_ci, 3)} |
| C-index | {fmt_num(cidx, 3)} | {fmt_ci(cidx_ci, 3)} |
| AUROC (lowest efficiency quintile) | {fmt_num(auc20, 3)} | {fmt_ci(auc20_ci, 3)} |
| AUROC (lowest efficiency quartile) | {fmt_num(auc25, 3)} | {fmt_ci(auc25_ci, 3)} |

### Risk Stratification (24h Efficiency)
| Risk Category | Score | N | Median [IQR] mL/mg |
|---------------|-------|---|-------------------|
| {low_row[0]} | {low_row[1]} | {low_row[2]} | {low_row[3]} |
| {mod_row[0]} | {mod_row[1]} | {mod_row[2]} | {mod_row[3]} |
| {high_row[0]} | {high_row[1]} | {high_row[2]} | {high_row[3]} |

**Fold difference in median efficiency (Low vs High): {fold_str}**

### Secondary Outcomes
| Outcome | N | Value | AUROC (95% CI) |
|---------|---|-------|----------------|
| 72h Efficiency | {n72} | ρ = {fmt_num(rho72, 3)} | {fmt_num(auc72, 3)} {fmt_ci(auc72_ci, 3)} |
| Diuretic Resistance | {dr_n} | {fmt_num(dr_prev, 1)}% | {fmt_num(dr_auc, 3)} {fmt_ci(dr_auc_ci, 3)} |
| Mortality | {mort_n} | {fmt_num(mort_rate, 1)}% | {fmt_num(mort_auc, 3)} {fmt_ci(mort_auc_ci, 3)} |

---

## Folder Structure
```
BAN_ADHF_Validation/
├── README.md
├── Tables/
│   ├── Main/
│   └── Supplementary/
├── Figures/
│   ├── Main/
│   ├── Supplementary/
│   └── R_Data/
├── Manuscript/
├── Checklists/
└── Data/
```

---

## Figures

### Main Figures
- Figure 1: Study flow diagram (to be created in PowerPoint/Illustrator)
- Figure 2: Diuretic efficiency by risk category (A: 24h, B: 72h)
- Figure 3: Subgroup forest plot

### Supplementary Figures
- eFigure 1: ROC curves for binary outcomes (lowest efficiency quintile, diuretic resistance, mortality)
- eFigure 2: BAN-ADHF score distribution (by risk category and mortality)
- eFigure 3: Correlation scatter plot (score vs 24h efficiency)

---

## Significant Interactions (p < 0.05)
| Subgroup | p-interaction | Summary |
|----------|---------------|---------|
{sig_interactions_md}

---

## Statistical Software
- Python 3.x with pandas, numpy, scipy, scikit-learn
- R (optional for publication-quality figures)
- BigQuery for MIMIC-IV extraction

---

## Contact
[Add author contact information]

## License
[Add license information]

## Citation
[Add citation once published]
"""

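# README contains non-ASCII characters (ρ, ≤, ≥, tree glyphs), so write UTF-8 explicitly
with open(f'{base_dir}/README.md', 'w', encoding='utf-8') as f: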
    f.write(readme_content)
print("✓ README.md created")

#==========================================================================
# SAVE ANALYSIS COHORT
#==========================================================================

print("\n" + "-"*70)
print("Saving analysis cohort...")
print("-"*70)

key_vars = [
    'hadm_id', 'age', 'gender', 'ban_adhf_total_score', 'risk_category',
    'creatinine', 'bun', 'ntprobnp', 'dbp', 'total_furosemide_equivalent_mg',
    'hx_atrial_fibrillation', 'hx_hypertension', 'prior_hf_hospitalization_12mo',
    'lvef', 'hf_phenotype', 'hx_diabetes', 'hx_renal_disease',
    'hx_myocardial_infarction', 'hx_stroke', 'hx_copd', 'cci_score',
    'cardiogenic_shock', 'invasive_vent',
    'diuretic_efficiency_24h', 'diuretic_efficiency_72h',
    'urine_output_24h_ml', 'diuretic_resistance',
    'hospital_expire_flag', 'icu_stay_ge_24h', 'icu_stay_ge_72h'
]

available_vars = [v for v in key_vars if v in df.columns]
df_export = df[available_vars].copy()
df_export.to_csv(f'{base_dir}/Data/analysis_cohort.csv', index=False)
print(f"✓ Analysis cohort saved ({len(df_export):,} patients, {len(available_vars)} variables)")

#==========================================================================
# SAVE STORED RESULTS AS PICKLE (for reproducibility)
#==========================================================================

results_to_save = {
    'results_24h': results_24h,
    'results_72h': results_72h,
    'results_dr': results_dr,
    'results_mortality': results_mortality,
    'results_sensitivity': results_sensitivity,
    'results_secondary': results_secondary,
    'results_subgroups': results_subgroups,
    'eff_24h_by_risk': eff_24h_by_risk,
    'eff_72h_by_risk': eff_72h_by_risk
}

with open(f'{base_dir}/Data/analysis_results.pkl', 'wb') as f:
    pickle.dump(results_to_save, f)
print("✓ Analysis results saved (pickle)")

#==========================================================================
# CREATE ZIP FILE
#==========================================================================

print("\n" + "-"*70)
print("Creating ZIP archive...")
print("-"*70)

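# Build the archive next to base_dir instead of hard-coding /content
# (with base_dir = '/content/BAN_ADHF_Validation' this is /content/BAN_ADHF_Validation.zip)
base_abs = os.path.abspath(base_dir)
parent_dir, folder_name = os.path.split(base_abs)
zip_stem = os.path.join(parent_dir, folder_name)
zip_out = zip_stem + '.zip'
if os.path.exists(zip_out):
    os.remove(zip_out)

# make_archive returns the path of the archive it wrote
zip_out = shutil.make_archive(zip_stem, 'zip', parent_dir, folder_name)
print(f"✓ ZIP archive created: {zip_out}")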

zip_size = os.path.getsize(zip_out) / (1024 * 1024)
print(f"  Size: {zip_size:.2f} MB")

#==========================================================================
# FINAL INVENTORY
#==========================================================================

print("\n" + "="*70)
print("FILE INVENTORY")
print("="*70)

def count_files(directory, extension=None):
    count = 0
    for root, dirs, files in os.walk(directory):
        for f in files:
            if extension is None or f.endswith(extension):
                count += 1
    return count

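def list_files(directory):
    """Print an alphabetically sorted tree of the folders and files under directory."""
    for root, dirs, files in os.walk(directory):
        dirs.sort()  # in-place sort makes os.walk visit subfolders in a stable order
        rel = os.path.relpath(root, directory)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 2 * level
        print(f'{indent}{os.path.basename(root)}/')
        subindent = ' ' * 2 * (level + 1)
        for f in sorted(files):
            print(f'{subindent}{f}')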

list_files(base_dir)
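`count_files` matches extensions case-sensitively and only counts the types it is asked about. A hedged alternative that tallies every file by lower-cased extension in one pass, using `pathlib` and `collections.Counter` (`count_by_extension` is hypothetical; `.tif` and `.tiff` stay separate keys):

```python
from collections import Counter
from pathlib import Path

def count_by_extension(directory):
    """Count files under directory by lower-cased suffix ('' for files without one)."""
    return Counter(p.suffix.lower() for p in Path(directory).rglob('*') if p.is_file())

# Hypothetical usage: print(count_by_extension(base_dir).most_common())
```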

#==========================================================================
# SUMMARY
#==========================================================================

print("\n" + "="*70)
print("PACKAGING COMPLETE")
print("="*70)

n_csv = count_files(base_dir, '.csv')
n_xlsx = count_files(base_dir, '.xlsx')
n_png = count_files(base_dir, '.png')
n_tiff = count_files(base_dir, '.tiff') + count_files(base_dir, '.tif')
n_txt = count_files(base_dir, '.txt')
n_md = count_files(base_dir, '.md')
n_pkl = count_files(base_dir, '.pkl')
n_all = count_files(base_dir)  # every file, regardless of extension
n_other = n_all - (n_csv + n_xlsx + n_png + n_tiff + n_txt + n_md + n_pkl)

print(f"""
CONTENTS SUMMARY:
  CSV files:   {n_csv}
  XLSX files:  {n_xlsx}
  PNG files:   {n_png}
  TIFF/TIF:    {n_tiff}
  TXT files:   {n_txt}
  MD files:    {n_md}
  PKL files:   {n_pkl}
  Other:       {n_other}
  ─────────────
  TOTAL:       {n_all} files

DOWNLOAD:
  File: {zip_out}
  Size: {zip_size:.2f} MB
""")

print("\n" + "="*70)
print("✓ ALL OUTPUTS PACKAGED SUCCESSFULLY")
print("="*70)

#==========================================================================
# DOWNLOAD ZIP FILE
#==========================================================================

from google.colab import files

# Download the ZIP file
files.download('/content/BAN_ADHF_Validation.zip')

print("✓ Download initiated - check your browser's download folder")