# Mechanical Power Personalisation for ICU Patients

**End-to-end pipeline**: Data extraction → Preprocessing → MDP construction → Model training → Evaluation

Three strategies of increasing complexity are compared:
1. **S1 — Static XGBoost** baseline (admission features)
2. **S2 — Time-Window XGBoost** baseline (trajectory features)
3. **S3 — Conservative Q-Learning (CQL)** offline reinforcement-learning (RL) agent

---

In [None]:
# ============================================================
# Cell 1: Environment setup
# ============================================================
import os
import sys
import warnings

# Silence library warnings so notebook output stays readable.
warnings.filterwarnings('ignore')

# Treat the parent of the notebook's working directory as the project root and
# put it first on the import path, so the `config` and `src` packages imported
# in Cell 3 resolve. The membership check keeps re-runs from stacking entries.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print(f'Project root: {PROJECT_ROOT}')
print(f'Python: {sys.version}')

CWD: /content
Contents: ['.config', 'sample_data']
Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
Home: /root


In [None]:
# ============================================================
# Cell 2: Install dependencies (run once)
# ============================================================
!pip install -q pandas numpy scikit-learn xgboost torch matplotlib seaborn tqdm

In [None]:
# ============================================================
# Cell 3: Imports
# ============================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm

# Project modules
from config.settings import (
    DATA_DIR, ARTEFACTS_DIR, N_ACTIONS, ACTIONS, ACTION_DELTAS,
    STATE_DIM, REWARD_CONFIG, RL_CONFIG
)
from src.data_extraction import (
    build_cohort, extract_all, load_table
)
from src.preprocessing import preprocess_pipeline
from src.mdp_dataset import (
    build_episodes, episodes_to_arrays, split_episodes,
    StateNormaliser, ALL_STATE_FEATURES
)
from src.models import (
    prepare_xgboost_features_s1, prepare_xgboost_features_s2,
    train_xgboost, CQLAgent, SafetyFilter
)
from src.evaluation import (
    evaluate_classifier, off_policy_evaluation,
    safety_audit, comparison_table
)
from src.explainability import (
    plot_cohort_summary, plot_mp_distribution,
    plot_ventilator_parameters, plot_training_curves,
    plot_feature_importance, plot_policy_comparison,
    plot_reward_analysis, display_clinical_decision,
    plot_action_value_heatmap
)

# Ensure artefacts dir exists. Later cells write into it: the CQL checkpoint,
# the action-distribution figure and the strategy-comparison CSV.
# parents=True creates missing parent dirs; exist_ok=True makes re-runs harmless.
ARTEFACTS_DIR.mkdir(parents=True, exist_ok=True)

print('All modules loaded successfully.')

---
## Phase 1: Data Extraction & Cohort Construction

In [None]:
# ============================================================
# Cell 4: Build study cohort
# ============================================================
# Build the study cohort and print headline descriptive statistics.
# `cohort` is a global reused by the extraction, preprocessing, episode
# building and summary cells below.
print('Building study cohort...')
cohort = build_cohort()
print(f'\nCohort size: {len(cohort)} ICU stays')
# Mean of the mortality column shown as a percentage; presumably a 0/1 flag
# (TODO confirm in src.data_extraction).
print(f'Hospital mortality: {cohort["hospital_mortality"].mean():.1%}')
print(f'Age: {cohort["age"].median():.0f} years (median)')
print(f'ICU LOS: {cohort["icu_los_hours"].median():.0f} hours (median)')
# Last expression of the cell, so the notebook renders the first rows.
cohort.head()

In [None]:
# ============================================================
# Cell 5: Cohort summary visualisation
# ============================================================
# Draw the cohort overview figure via the explainability helper and show it inline.
fig = plot_cohort_summary(cohort)
plt.show()

In [None]:
# ============================================================
# Cell 6: Extract all clinical data
# ============================================================
# Pull the raw clinical tables for the cohort. `raw_data` is a mapping from a
# table name to a table-like object exposing `.shape` (presumably pandas
# DataFrames — confirm in src.data_extraction); Cell 7 consumes it.
print('Extracting clinical data (ventilator, vitals, labs, anthropometrics)...')
raw_data = extract_all(cohort)

print('\nExtracted data shapes:')
for key, df in raw_data.items():
    print(f'  {key:20s}: {df.shape}')

---
## Phase 2: Preprocessing

In [None]:
# ============================================================
# Cell 7: Run preprocessing pipeline
# ============================================================
# Turn the raw tables into one hourly time series per ICU stay. The printed
# step list describes what preprocess_pipeline is expected to do; the
# implementation lives in src.preprocessing.
print('Running preprocessing pipeline...')
print('  Steps: outlier removal → merging → standardisation → ')
print('         MP calculation → derived variables → hourly resampling')

hourly_data = preprocess_pipeline(raw_data, cohort)

# One row per (stay, hour), judging by the counts below: unique icustay_id
# values are stays and the row count is reported as total hours.
print(f'\nHourly data shape: {hourly_data.shape}')
print(f'ICU stays with data: {hourly_data["icustay_id"].nunique()}')
print(f'Total hours: {len(hourly_data)}')
print(f'\nColumn list:')
print(hourly_data.columns.tolist())
# Last expression of the cell, so the notebook renders the first rows.
hourly_data.head()

In [None]:
# ============================================================
# Cell 8: Mechanical Power analysis
# ============================================================
# Plot the hourly mechanical-power distribution (the helper also receives the
# cohort table) and show it inline.
fig = plot_mp_distribution(hourly_data, cohort)
plt.show()

# Descriptive statistics of the MP column. `mp_stats` is a global that the
# pipeline summary (Cell 27) reads for the median.
mp_stats = hourly_data['mechanical_power'].describe()
print('\nMechanical Power statistics:', mp_stats, sep='\n')

In [None]:
# ============================================================
# Cell 9: Ventilator parameter distributions
# ============================================================
# Draw the ventilator-parameter distribution figure from the hourly data and show it inline.
fig = plot_ventilator_parameters(hourly_data)
plt.show()

In [None]:
# ============================================================
# Cell 10: Data quality report
# ============================================================
print('Data Quality Report')
print('=' * 60)

# Fraction of missing values per column, worst first, with a 40-character
# text bar proportional to that fraction.
missing = hourly_data.isnull().mean().sort_values(ascending=False)
print('\nMissing data fraction per column:')
for col, frac in missing.items():
    print(f'  {col:25s} {frac:5.1%} {"█" * int(frac * 40)}')

# Per ICU stay: share of hourly rows that carry a mechanical-power value.
mp_avail = (
    hourly_data['mechanical_power']
    .notna()
    .groupby(hourly_data['icustay_id'])
    .mean()
)
print(f'\nMP availability per stay: {mp_avail.mean():.1%} (mean)')
print(f'Stays with >50% MP data: {(mp_avail > 0.5).sum()} / {len(mp_avail)}')

---
## Phase 3: MDP Dataset Construction

In [None]:
# ============================================================
# Cell 11: Build episodes
# ============================================================
print('Building MDP episodes...')
episodes = build_episodes(hourly_data, cohort)
print(f'\nTotal episodes: {len(episodes)}')

if len(episodes) == 0:
    print('WARNING: No episodes constructed. Check data availability.')
else:
    # Episode-length and outcome summary across all episodes.
    lengths = [len(ep['states']) for ep in episodes]
    mort = [ep['patient_info']['hospital_mortality'] for ep in episodes]
    print(f'Episode lengths: {np.mean(lengths):.1f} ± {np.std(lengths):.1f} steps')
    print(f'Mortality in episodes: {np.mean(mort):.1%}')

    # How often the clinicians (behaviour policy) chose each discrete action.
    all_actions = [a for ep in episodes for a in ep['actions']]
    n_total = max(len(all_actions), 1)  # guard the division for empty action lists
    print(f'\nBehaviour policy action distribution:')
    for idx in range(N_ACTIONS):
        n_taken = all_actions.count(idx)
        print(f'  {ACTIONS[idx]:18s}: {n_taken:5d} ({n_taken / n_total:5.1%})')

In [None]:
# ============================================================
# Cell 12: Reward analysis
# ============================================================
# Visualise the reward signal of the constructed episodes (if there are any).
if len(episodes) == 0:
    print('No episodes to analyse.')
else:
    fig = plot_reward_analysis(episodes)
    plt.show()

In [None]:
# ============================================================
# Cell 13: Convert to arrays & split
# ============================================================
if len(episodes) == 0:
    print('Skipping — no episodes available.')
else:
    print('Converting episodes to arrays...')
    feature_list = ALL_STATE_FEATURES

    # Whole-dataset transition arrays, used here for the summary printout.
    states, actions, rewards, next_states, dones = episodes_to_arrays(
        episodes, feature_list
    )
    print(f'States shape:  {states.shape}')
    print(f'Actions shape: {actions.shape}')
    print(f'Rewards range: [{rewards.min():.1f}, {rewards.max():.1f}]')
    print(f'Terminal steps: {dones.sum()}')

    # Patient-level train/val/test split, performed over whole episodes.
    train_ep, val_ep, test_ep = split_episodes(episodes)
    print(f'\nSplit: train={len(train_ep)}, val={len(val_ep)}, test={len(test_ep)}')

    # Per-split transition arrays: (state, action, reward, next state, done).
    train_s, train_a, train_r, train_ns, train_d = episodes_to_arrays(train_ep, feature_list)
    val_s, val_a, val_r, val_ns, val_d = episodes_to_arrays(val_ep, feature_list)
    test_s, test_a, test_r, test_ns, test_d = episodes_to_arrays(test_ep, feature_list)

    # Normalisation statistics come from the training states only, then the
    # same transform is applied to states and next-states of every split.
    normaliser = StateNormaliser()
    normaliser.fit(train_s)
    train_s_n, val_s_n, test_s_n = (
        normaliser.transform(arr) for arr in (train_s, val_s, test_s)
    )
    train_ns_n, val_ns_n, test_ns_n = (
        normaliser.transform(arr) for arr in (train_ns, val_ns, test_ns)
    )

    print(f'\nNormalised state dim: {train_s_n.shape[1]}')
    print('Dataset construction complete.')

---
## Phase 4: Strategy S1 — Static XGBoost Baseline

In [None]:
# ============================================================
# Cell 14: S1 — Prepare features
# ============================================================
if len(episodes) == 0:
    print('Skipping — no episodes available.')
else:
    print('Preparing S1 features (admission snapshot)...')

    # The feature-name list is taken from the training split only;
    # val/test are assumed to produce the same columns.
    X_train_s1, y_train_s1, feat_s1 = prepare_xgboost_features_s1(train_ep)
    X_val_s1, y_val_s1, _ = prepare_xgboost_features_s1(val_ep)
    X_test_s1, y_test_s1, _ = prepare_xgboost_features_s1(test_ep)

    print(f'S1 features: {len(feat_s1)}')
    print(f'Train: {X_train_s1.shape}, Val: {X_val_s1.shape}, Test: {X_test_s1.shape}')
    print(f'Mortality rate — Train: {y_train_s1.mean():.2%}, Test: {y_test_s1.mean():.2%}')

In [None]:
# ============================================================
# Cell 15: S1 — Train XGBoost
# ============================================================
if len(episodes) > 0:
    print('Training S1 XGBoost classifier...')

    # Class-imbalance weight: negative/positive ratio of the training labels,
    # with the positive rate floored at 1% and the weight floored at 1.
    pos_rate_s1 = y_train_s1.mean()
    s1_params = {
        'max_depth': 4,
        'learning_rate': 0.05,
        'n_estimators': 200,
        'subsample': 0.8,
        'colsample_bytree': 0.8,
        'scale_pos_weight': max(1, (1 - pos_rate_s1) / max(pos_rate_s1, 0.01)),
    }
    model_s1 = train_xgboost(
        X_train_s1, y_train_s1, X_val_s1, y_val_s1, params=s1_params
    )

    # Held-out evaluation on predicted positive-class probabilities.
    y_pred_s1 = model_s1.predict_proba(X_test_s1)[:, 1]
    metrics_s1 = evaluate_classifier(y_test_s1, y_pred_s1)
    print('\nS1 Test Metrics:')
    for k, v in metrics_s1.items():
        print(f'  {k:20s}: {v:.4f}' if isinstance(v, (int, float)) else f'  {k:20s}: {v}')

    # Top features by importance (at most 15).
    fig = plot_feature_importance(model_s1, feat_s1, top_n=min(15, len(feat_s1)))
    plt.show()
else:
    print('Skipping — no episodes available.')

---
## Phase 5: Strategy S2 — Time-Window XGBoost

In [None]:
# ============================================================
# Cell 16: S2 — Prepare features & train
# ============================================================
if len(episodes) > 0:
    print('Preparing S2 features (multi-timepoint + trajectory stats)...')

    # Same window offsets for every split. A fresh list is passed on each call,
    # so no call can observe another's mutations.
    S2_WINDOWS = (0, 6, 12, 24)
    X_train_s2, y_train_s2, feat_s2 = prepare_xgboost_features_s2(
        train_ep, feature_list=feature_list, windows=list(S2_WINDOWS)
    )
    X_val_s2, y_val_s2, _ = prepare_xgboost_features_s2(
        val_ep, feature_list=feature_list, windows=list(S2_WINDOWS)
    )
    X_test_s2, y_test_s2, _ = prepare_xgboost_features_s2(
        test_ep, feature_list=feature_list, windows=list(S2_WINDOWS)
    )

    print(f'S2 features: {len(feat_s2)}')
    print(f'Train: {X_train_s2.shape}, Val: {X_val_s2.shape}, Test: {X_test_s2.shape}')

    print('\nTraining S2 XGBoost classifier...')
    # Imbalance weight computed exactly as for S1 (ratio floored at 1).
    pos_rate_s2 = y_train_s2.mean()
    s2_params = {
        'max_depth': 5,
        'learning_rate': 0.05,
        'n_estimators': 300,
        'subsample': 0.8,
        'colsample_bytree': 0.7,
        'scale_pos_weight': max(1, (1 - pos_rate_s2) / max(pos_rate_s2, 0.01)),
    }
    model_s2 = train_xgboost(
        X_train_s2, y_train_s2, X_val_s2, y_val_s2, params=s2_params
    )

    y_pred_s2 = model_s2.predict_proba(X_test_s2)[:, 1]
    metrics_s2 = evaluate_classifier(y_test_s2, y_pred_s2)
    print('\nS2 Test Metrics:')
    for k, v in metrics_s2.items():
        print(f'  {k:20s}: {v:.4f}' if isinstance(v, (int, float)) else f'  {k:20s}: {v}')

    fig = plot_feature_importance(model_s2, feat_s2, top_n=min(15, len(feat_s2)))
    plt.show()
else:
    print('Skipping — no episodes available.')

---
## Phase 6: Strategy S3 — Conservative Q-Learning (CQL)

In [None]:
# ============================================================
# Cell 17: Initialise CQL Agent
# ============================================================
# Builds the global `agent` only when episodes exist AND PyTorch is available.
# Later cells check `'agent' in dir()` before using it.
if len(episodes) > 0:
    from src.models import TORCH_AVAILABLE

    if TORCH_AVAILABLE:
        # The agent is trained on the normalised feature matrix from Cell 13,
        # so its input width must equal that matrix's column count. Take it
        # from the data rather than assuming STATE_DIM stays in sync with
        # ALL_STATE_FEATURES, and warn if the two disagree.
        state_dim = train_s_n.shape[1]
        if state_dim != STATE_DIM:
            print(f'WARNING: STATE_DIM={STATE_DIM} but the state matrix has '
                  f'{state_dim} features — using {state_dim}.')

        agent = CQLAgent(
            state_dim=state_dim,
            n_actions=N_ACTIONS,
            hidden_dim=RL_CONFIG.get('hidden_dim', 256),
            lr=RL_CONFIG.get('lr', 1e-4),
            gamma=RL_CONFIG.get('gamma', 0.99),
            cql_alpha=RL_CONFIG.get('cql_alpha', 1.0),
            tau=RL_CONFIG.get('tau', 0.005),
            batch_size=RL_CONFIG.get('batch_size', 256),
            dropout=0.3,
        )
        print(f'CQL Agent initialised on {agent.device}')
        print(f'  State dim:  {state_dim}')
        print(f'  Actions:    {N_ACTIONS}')
        print(f'  Hidden dim: {RL_CONFIG.get("hidden_dim", 256)}')
        print(f'  CQL alpha:  {RL_CONFIG.get("cql_alpha", 1.0)}')
        print(f'  LR:         {RL_CONFIG.get("lr", 1e-4)}')
    else:
        print('WARNING: PyTorch not available — skipping CQL agent initialisation.')
        print('Strategies S1 and S2 (XGBoost) will still work.')
else:
    print('Skipping — no episodes available.')

In [None]:
# ============================================================
# Cell 18: Train CQL agent
# ============================================================
# Cell 17 intentionally leaves `agent` undefined when PyTorch is missing, so
# guard on it here. Otherwise this cell would raise NameError and break the
# "S1/S2 still work" fallback.
if len(episodes) > 0 and 'agent' in dir():
    n_epochs = RL_CONFIG.get('n_epochs', 50)
    batch_size = RL_CONFIG.get('batch_size', 256)

    print(f'Training CQL for {n_epochs} epochs (batch_size={batch_size})...')
    # Offline training on the normalised training transitions from Cell 13.
    history = agent.train(
        states=train_s_n,
        actions=train_a,
        rewards=train_r,
        next_states=train_ns_n,
        dones=train_d,
        n_epochs=n_epochs,
        batch_size=batch_size,
    )

    # `history` is a dict of per-step/epoch sequences; the last entries are the final values.
    print(f'\nFinal loss: {history["loss"][-1]:.4f}')
    print(f'Final Q-mean: {history["q_mean"][-1]:.4f}')

    # Save model
    agent.save(ARTEFACTS_DIR / 'cql_agent.pt')
    print(f'Model saved to {ARTEFACTS_DIR / "cql_agent.pt"}')
else:
    print('Skipping — no episodes available or CQL agent not initialised.')

In [None]:
# ============================================================
# Cell 19: Training curves
# ============================================================
# `history` exists only if Cell 18 actually trained an agent, so check for it
# directly instead of assuming episodes imply training happened.
if len(episodes) > 0 and 'history' in dir():
    fig = plot_training_curves(history)
    plt.show()
else:
    print('Skipping — no training history.')

---
## Phase 7: Evaluation & Off-Policy Estimation

In [None]:
# ============================================================
# Cell 20: CQL – Action distribution on test set
# ============================================================
# Requires the CQL agent (absent when PyTorch is unavailable — see Cell 17).
if len(episodes) > 0 and 'agent' in dir():
    print('CQL policy evaluation on test set...')

    # Query the agent one normalised test state at a time. agent.predict returns
    # a dict carrying at least 'action' and 'q_values'.
    cql_actions = []
    cql_q_values = []
    for i in range(len(test_s_n)):
        pred = agent.predict(test_s_n[i])
        cql_actions.append(pred['action'])
        cql_q_values.append(pred['q_values'])

    cql_actions = np.array(cql_actions)

    # Compare action distributions side by side. The bar colours are
    # loop-invariant, so they are defined once outside the loop.
    colors = ['#2ecc71', '#85c1e9', '#95a5a6', '#f5b041', '#c0392b']
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle('Action Distribution: Behaviour vs. CQL Policy', fontsize=14)

    for ax, actions_arr, title in [
        (axes[0], test_a, 'Clinician (Behaviour)'),
        (axes[1], cql_actions, 'CQL Policy (Learned)'),
    ]:
        counts = [np.sum(actions_arr == a) for a in range(N_ACTIONS)]
        ax.bar(range(N_ACTIONS), counts, color=colors)
        ax.set_xticks(range(N_ACTIONS))
        ax.set_xticklabels([ACTIONS[a] for a in range(N_ACTIONS)], rotation=30, ha='right')
        ax.set_ylabel('Count')
        ax.set_title(title)

    plt.tight_layout()
    plt.savefig(ARTEFACTS_DIR / 'action_distributions.png', bbox_inches='tight')
    plt.show()

    # Agreement rate: fraction of test transitions where the CQL action equals
    # the clinician's. Undefined (mean of an empty array) when there are none.
    if len(cql_actions) > 0:
        agreement = (cql_actions == test_a).mean()
        print(f'\nPolicy agreement with clinicians: {agreement:.1%}')
    else:
        print('\nPolicy agreement with clinicians: n/a (empty test set)')
else:
    print('Skipping.')

In [None]:
# ============================================================
# Cell 21: Off-Policy Evaluation
# ============================================================
# Requires the CQL agent (absent when PyTorch is unavailable — see Cell 17).
# `ope_results` is only defined here; Cell 23 checks for it before use.
if len(episodes) > 0 and 'agent' in dir():
    print('Running Off-Policy Evaluation (OPE)...')
    ope_results = off_policy_evaluation(
        agent=agent,
        episodes=test_ep,
        feature_list=feature_list,
        gamma=RL_CONFIG.get('gamma', 0.99)
    )

    print('\nOPE Results:')
    for k, v in ope_results.items():
        if isinstance(v, float):
            print(f'  {k:30s}: {v:.4f}')
        else:
            print(f'  {k:30s}: {v}')
else:
    print('Skipping.')

In [None]:
# ============================================================
# Cell 22: Safety audit
# ============================================================
# Requires the CQL agent (absent when PyTorch is unavailable — see Cell 17).
if len(episodes) > 0 and 'agent' in dir():
    print('Running safety audit...')
    # Global `safety_filter` is also reused by the clinical-decision example (Cell 26).
    safety_filter = SafetyFilter()

    safety_results = safety_audit(
        agent=agent,
        episodes=test_ep,
        feature_list=feature_list,
        safety_filter=safety_filter
    )

    # Results may nest one level of dicts; print those as indented sub-items.
    print('\nSafety Audit Results:')
    for k, v in safety_results.items():
        if isinstance(v, dict):
            print(f'  {k}:')
            for sk, sv in v.items():
                print(f'    {sk}: {sv}')
        else:
            print(f'  {k:30s}: {v}')
else:
    print('Skipping.')

---
## Phase 8: Strategy Comparison

In [None]:
# ============================================================
# Cell 23: Comparison table
# ============================================================
if len(episodes) == 0:
    print('Skipping.')
else:
    # One row per strategy whose results exist in the notebook namespace.
    # Classifier rows carry their full metric dict; the CQL row carries only
    # the two OPE summary numbers (NaN when OPE did not report them).
    results_list = []
    if 'metrics_s1' in dir():
        results_list.append({'strategy': 'S1_Static_XGBoost', **metrics_s1})
    if 'metrics_s2' in dir():
        results_list.append({'strategy': 'S2_TimeWindow_XGBoost', **metrics_s2})
    if 'ope_results' in dir():
        results_list.append({
            'strategy': 'S3_CQL_RL',
            'dm_estimate': ope_results.get('dm_estimate', np.nan),
            'agreement_rate': ope_results.get('agreement_rate', np.nan),
        })

    if not results_list:
        print('No results to compare.')
    else:
        comp_df = pd.DataFrame(results_list).set_index('strategy')
        print('Strategy Comparison')
        print('=' * 70)
        print(comp_df.to_string())
        csv_path = ARTEFACTS_DIR / 'strategy_comparison.csv'
        comp_df.to_csv(csv_path)
        print(f'\nSaved to {csv_path}')

---
## Phase 9: Policy Visualisation

In [None]:
# ============================================================
# Cell 24: Behaviour vs. learned policy trajectories
# ============================================================
# Plot clinician vs. CQL trajectories for up to three test episodes.
if len(episodes) == 0 or 'agent' not in dir():
    print('Skipping.')
else:
    n_show = min(3, len(test_ep))
    fig = plot_policy_comparison(
        episodes=test_ep[:n_show],
        agent=agent,
        feature_list=feature_list,
        n_episodes=n_show,
    )
    plt.show()

In [None]:
# ============================================================
# Cell 25: Q-value heatmap
# ============================================================
# Heatmap of the agent's action values over the test episodes.
if len(episodes) == 0 or 'agent' not in dir():
    print('Skipping.')
else:
    fig = plot_action_value_heatmap(agent, test_ep, feature_list=feature_list)
    plt.show()

In [None]:
# ============================================================
# Cell 26: Clinical decision display (example)
# ============================================================
# Needs a trained agent AND at least one test episode. A small cohort can
# yield an empty test split, and test_ep[0] would then raise IndexError.
if len(episodes) > 0 and 'agent' in dir() and len(test_ep) > 0:
    # Pick a sample state from the test set
    sample_ep = test_ep[0]
    sample_state = sample_ep['states'][0]

    # Apply the same flatten + normalise steps used to build the training arrays.
    from src.mdp_dataset import flatten_state
    s = flatten_state(sample_state, feature_list)
    s_n = normaliser.transform(s.reshape(1, -1))[0]
    prediction = agent.predict(s_n)

    # Safety check. Cell 22 normally creates the filter; build one here if
    # that cell was skipped so this example still runs on its own.
    if 'safety_filter' not in dir():
        safety_filter = SafetyFilter()
    safe_action, alert = safety_filter.check(sample_state, prediction['action'])
    if safe_action != prediction['action']:
        alert_msg = f'Action overridden: {ACTIONS[prediction["action"]]} → {ACTIONS[safe_action]} ({alert})'
    else:
        alert_msg = None

    display_clinical_decision(sample_state, prediction, safety_alert=alert_msg)
else:
    print('Skipping.')

---
## Phase 10: Summary & Next Steps

In [None]:
# ============================================================
# Cell 27: Pipeline summary
# ============================================================
# Print a recap of whatever the earlier cells produced. Optional results are
# detected by name in the notebook namespace, so skipped stages are omitted.
print('=' * 60)
print('  PIPELINE SUMMARY')
print('=' * 60)

print(f'\nData:')
print(f'  Cohort size:      {len(cohort)} ICU stays')
if len(episodes) > 0:
    print(f'  Episodes built:   {len(episodes)}')
    print(f'  Total transitions: {sum(len(ep["states"]) for ep in episodes)}')
    print(f'  State dimension:  {STATE_DIM}')

print(f'\nClinical:')
print(f'  Mortality rate:   {cohort["hospital_mortality"].mean():.1%}')
if 'mp_stats' in dir():
    print(f'  Median MP:        {mp_stats["50%"]:.1f} J/min')

print(f'\nModels trained:')
# Check for the metrics (what is actually read), not just the model. Format
# numpy floating scalars as well as Python floats.
if 'metrics_s1' in dir():
    auroc_s1 = metrics_s1.get("auroc", "N/A")
    if isinstance(auroc_s1, (float, np.floating)):
        auroc_s1 = f'{auroc_s1:.4f}'
    print(f'  S1 -- Static XGBoost (AUROC: {auroc_s1})')
if 'metrics_s2' in dir():
    auroc_s2 = metrics_s2.get("auroc", "N/A")
    if isinstance(auroc_s2, (float, np.floating)):
        auroc_s2 = f'{auroc_s2:.4f}'
    print(f'  S2 -- Time-Window XGBoost (AUROC: {auroc_s2})')
if 'agent' in dir():
    # n_epochs/history exist only if Cell 18 ran training.
    if 'history' in dir() and 'n_epochs' in dir():
        print(f'  S3 -- CQL RL Agent (trained {n_epochs} epochs)')
    else:
        print('  S3 -- CQL RL Agent (initialised, not trained)')

print(f'\nArtefacts saved to: {ARTEFACTS_DIR}')
# List regular files only. A sub-directory's "size" from stat is not the size
# of its contents and would be misleading here.
for artefact in sorted(p for p in ARTEFACTS_DIR.iterdir() if p.is_file()):
    size = artefact.stat().st_size / 1024
    print(f'  {artefact.name:40s} {size:6.1f} KB')

print('\n' + '=' * 60)
print('  Pipeline complete.')
print('=' * 60)

In [None]:
# ============================================================
# Cell 28: Next steps
# ============================================================
# Static roadmap text, kept in a named constant and printed verbatim.
NEXT_STEPS_TEXT = '''
NEXT STEPS FOR PRODUCTION
=========================

1. Scale to full MIMIC-III/IV (>15,000 ventilated ICU stays)
   - Current demo has only 100 patients
   - Adjust min_vent_hours back to 24h
   - Enable proper cross-validation

2. Causal inference additions
   - Add propensity score estimation
   - Implement doubly-robust OPE estimators
   - DAG-based confounding analysis

3. Model improvements
   - Hyperparameter tuning (Optuna/Ray)
   - Ensemble of CQL agents (bagging)
   - Recurrent state encoder (LSTM/Transformer)

4. Clinical validation
   - Expert clinician review of recommendations
   - Subgroup fairness analysis
   - Prospective simulation study

5. Deployment
   - Real-time inference API
   - EHR integration (FHIR/HL7)
   - Monitoring dashboard
'''
print(NEXT_STEPS_TEXT)