 # ***ICU-DISPO***

### Import packages 

In [None]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
from sklearn.pipeline import make_pipeline  
from sklearn.preprocessing import StandardScaler 
from sklearn.ensemble import StackingClassifier 
from scipy.stats import norm
from sklearn.metrics import precision_score, recall_score, roc_auc_score, brier_score_loss
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.calibration import CalibratedClassifierCV  
from sklearn.isotonic import IsotonicRegression  
import warnings
from tqdm import tqdm
import time
from sklearn.utils import resample
import matplotlib.pyplot as plt  
from sklearn.calibration import calibration_curve
import warnings

### Import base learners

In [None]:
from sklearn.linear_model import LogisticRegressionCV, RidgeClassifier, ElasticNetCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from xgboost import XGBClassifier
import lightgbm as lgb

### `data_loading`
This function loads a CSV file and splits it into training and test sets (default test size = 0.3), stratified on the treatment `A`. It returns a dictionary with `train` and `test` entries, each holding the outcome (`Y`), treatment (`A`), covariates (`W`), covariates plus treatment (`W_A`), and the row indices used. It also prints a summary of the split, including sample counts and the treatment proportion in each set, to help verify balance.

In [None]:
def data_loading(file_path, test_size=0.3, random_state=42):
    """Read a CSV file and split it into train/test partitions stratified on treatment.

    The CSV must contain an outcome column ``Y`` and a treatment column ``A``; all other
    columns are treated as covariates.

    Parameters
    ----------
    file_path : str or path-like
        Location of the CSV file passed to ``pd.read_csv``.
    test_size : float, default 0.3
        Fraction of rows assigned to the test partition.
    random_state : int, default 42
        Seed for the split.

    Returns
    -------
    dict
        ``{'train': {...}, 'test': {...}}`` where each inner dict holds
        ``'Y'`` (ndarray), ``'A'`` (ndarray), ``'W'`` (covariate DataFrame),
        ``'W_A'`` (covariates plus ``A``) and ``'indices'`` (row positions in the CSV).
    """
    frame = pd.read_csv(file_path)
    outcome = frame["Y"].values
    treatment = frame["A"].values
    covariates = frame.drop(columns=["Y", "A"])
    covariates_with_treatment = frame.drop(columns=["Y"])

    n_total = len(outcome)
    train_idx, test_idx = train_test_split(
        np.arange(n_total),
        test_size=test_size,
        random_state=random_state,
        stratify=treatment,
    )
    partitions = (("train", train_idx), ("test", test_idx))

    print("📊 Data Split Summary:")
    print(f"   Total samples: {n_total}")
    for label, idx in (("Train", train_idx), ("Test", test_idx)):
        print(f"   {label} samples: {len(idx)} ({len(idx)/n_total*100:.1f}%)")
    for label, idx in (("Train", train_idx), ("Test", test_idx)):
        print(f"   {label} treatment prop: {treatment[idx].mean():.3f}")

    return {
        name: {
            'Y': outcome[idx],
            'A': treatment[idx],
            'W': covariates.iloc[idx],
            'W_A': covariates_with_treatment.iloc[idx],
            'indices': idx,
        }
        for name, idx in partitions
    }

### `data_preprocessing`
This function preprocesses the training (and, optionally, the test) feature data.
Missing values are filled with the per-column training median, and the features are then z-score standardized so they are suitable for modeling.
The medians and scaling parameters are learned from the training set only and reused unchanged for the test set, which avoids data leakage.

In [None]:
def data_preprocessing(W_train, W_test=None):
    """Median-impute missing values and z-score standardize feature matrices.

    Imputation medians and the StandardScaler are fitted on ``W_train`` only and then
    applied unchanged to ``W_test``, so no test-set information leaks into preprocessing.

    Parameters
    ----------
    W_train : pandas.DataFrame or array-like of shape (n_train, n_features)
        Training features; may contain NaN.
    W_test : pandas.DataFrame or array-like, optional
        Test features with the same columns as ``W_train``; may contain NaN.

    Returns
    -------
    tuple
        ``(W_train_df, W_test_df, scaler)`` when ``W_test`` is given, otherwise
        ``(W_train_df, scaler)``. The frames hold standardized values with a fresh
        RangeIndex and keep the input's column names when the input was a DataFrame;
        ``scaler`` is the fitted StandardScaler.
    """
    if isinstance(W_train, pd.DataFrame):
        # DataFrame.median skips NaN by default.
        median_train = W_train.median()
    else:
        W_train = np.asarray(W_train, dtype=float)
        # np.median returns NaN for any column that contains NaN, which made the
        # imputation below a no-op; nanmedian ignores the missing entries as intended.
        median_train = np.nanmedian(W_train, axis=0)

    def _impute(W):
        """Fill NaN entries of W with the training medians."""
        if isinstance(W, pd.DataFrame):
            return W.fillna(median_train)
        W = np.asarray(W, dtype=float)
        return np.where(np.isnan(W), median_train, W)

    def _to_frame(values, source):
        """Wrap standardized values in a DataFrame, keeping source column names if any."""
        columns = source.columns if isinstance(source, pd.DataFrame) else None
        return pd.DataFrame(values, columns=columns)

    scaler = StandardScaler()
    W_train_df = _to_frame(scaler.fit_transform(_impute(W_train)), W_train)

    if W_test is None:
        return W_train_df, scaler

    W_test_df = _to_frame(scaler.transform(_impute(W_test)), W_test)
    return W_train_df, W_test_df, scaler

### `get_base_learners`
- Linear Models
    - logistic_cv: Logistic regression with cross-validation (roc_auc scoring, balanced class weights).
    - logistic_l1: L1-regularized logistic regression (sparse features).
    - logistic_elastic: Logistic regression with elastic-net regularization (L1+L2).

- Tree-Based Models
    - rf: Random Forest with 150 trees, max depth = 10, class balancing.
    - extra_trees: Extremely randomized trees, shallower than RF.
    - gbm: Gradient Boosting (learning_rate=0.08, subsample=0.8, early stopping).
    - xgb: XGBoost classifier with regularization (reg_alpha, reg_lambda).
    - lgbm: LightGBM with class balancing and shrinkage parameters.

- Non-Linear Models
    - svm_rbf: Support Vector Machine with RBF kernel (scaled features).
    - svm_linear: Linear kernel SVM with regularization C=0.1.
    - mlp: Multi-layer Perceptron (3 hidden layers: 50-30-15, adaptive learning rate, early stopping).

- Simple Models
    - nb: Gaussian Naive Bayes (with feature scaling).
    - knn: k-Nearest Neighbors (k=15, distance weighting).
    - dt: Decision Tree (depth-limited, class-balanced).
Returns a list of `(name, estimator)` tuples, one per base learner.

In [None]:
def get_base_learners():
    """Build the library of candidate learners for the SuperLearner.

    Returns
    -------
    list of (str, estimator)
        ``(name, unfitted estimator)`` pairs in the format expected by
        ``StackingClassifier(estimators=...)``. A new list of new estimator
        objects is created on every call.
    """
    base_learners = [
        # ========== LINEAR MODELS ==========
        ('logistic_cv', LogisticRegressionCV(
            cv=5,
            max_iter=10000,
            random_state=42,
            solver='lbfgs',
            scoring='roc_auc',
            class_weight='balanced'
        )),

        ('logistic_l1', LogisticRegressionCV(
            cv=5,
            max_iter=10000,
            penalty='l1',
            solver='liblinear',
            random_state=42,
            tol=1e-4,
            scoring='roc_auc',
            class_weight='balanced'
        )),

        ('logistic_elastic', make_pipeline(
            StandardScaler(),
            LogisticRegressionCV(
                cv=5,
                penalty='elasticnet',
                solver='saga',
                l1_ratios=[0.1, 0.5, 0.7, 0.9],
                max_iter=5000,
                random_state=42,
                class_weight='balanced'
            )
        )),

        # ========== TREE-BASED MODELS ==========
        ('rf', RandomForestClassifier(
            n_estimators=150,
            max_depth=10,
            min_samples_split=15,
            min_samples_leaf=8,
            max_features='sqrt',
            class_weight='balanced',
            random_state=42,
            n_jobs=-1
        )),

        ('extra_trees', ExtraTreesClassifier(
            n_estimators=100,
            max_depth=8,
            min_samples_split=20,
            min_samples_leaf=10,
            max_features='sqrt',
            class_weight='balanced',
            random_state=42,
            n_jobs=-1
        )),

        # Early stopping: hold out 10% and stop after 10 rounds without improvement.
        ('gbm', GradientBoostingClassifier(
            n_estimators=150,
            max_depth=5,
            learning_rate=0.08,
            subsample=0.8,
            max_features='sqrt',
            random_state=42,
            validation_fraction=0.1,
            n_iter_no_change=10
        )),

        ('xgb', XGBClassifier(
            n_estimators=150,
            max_depth=5,
            learning_rate=0.08,
            subsample=0.8,
            colsample_bytree=0.8,
            reg_alpha=0.1,
            reg_lambda=0.1,
            # NOTE(review): scale_pos_weight=1 applies no class re-weighting, unlike the
            # class_weight='balanced' learners above — confirm this is intended.
            scale_pos_weight=1,
            random_state=42,
            eval_metric='logloss',
            verbosity=0,
            n_jobs=-1
        )),

        ('lgbm', lgb.LGBMClassifier(
            n_estimators=150,
            max_depth=5,
            learning_rate=0.08,
            subsample=0.8,
            # LightGBM only performs row bagging when subsample_freq > 0; with the
            # default of 0 the subsample=0.8 setting above was silently ignored.
            subsample_freq=1,
            colsample_bytree=0.8,
            reg_alpha=0.1,
            reg_lambda=0.1,
            random_state=42,
            verbosity=-1,
            n_jobs=-1,
            class_weight='balanced'
        )),

        # ========== NON-LINEAR MODELS ==========
        ('svm_rbf', make_pipeline(
            StandardScaler(),
            SVC(
                probability=True,
                kernel='rbf',
                C=1.0,
                gamma='scale',
                class_weight='balanced',
                random_state=42
            )
        )),

        ('svm_linear', make_pipeline(
            StandardScaler(),
            SVC(
                probability=True,
                kernel='linear',
                C=0.1,
                class_weight='balanced',
                random_state=42
            )
        )),

        ('mlp', make_pipeline(
            StandardScaler(),
            MLPClassifier(
                hidden_layer_sizes=(50, 30, 15),
                max_iter=3000,
                alpha=0.01,
                learning_rate='adaptive',
                early_stopping=True,
                validation_fraction=0.1,
                n_iter_no_change=15,
                random_state=42
            )
        )),

        # ========== SIMPLE MODELS ==========
        ('nb', make_pipeline(
            StandardScaler(),
            GaussianNB()
        )),

        ('knn', make_pipeline(
            StandardScaler(),
            KNeighborsClassifier(
                n_neighbors=15,
                weights='distance',
                n_jobs=-1
            )
        )),

        ('dt', DecisionTreeClassifier(
            max_depth=6,
            min_samples_split=30,
            min_samples_leaf=15,
            class_weight='balanced',
            random_state=42
        ))
    ]
    return base_learners

### `fit_superlearner`
This function trains a SuperLearner ensemble using scikit-learn’s `StackingClassifier`.
It combines a diverse set of base learners with a meta-learner (cross-validated logistic regression)
and reports the 3-fold cross-validated ROC AUC on the training data for monitoring. If fitting the stack fails, it falls back to a plain logistic regression.

In [None]:
def fit_superlearner(X, Y, base_learners, model_name="SuperLearner"):
    """Fit a stacked SuperLearner and report its 3-fold cross-validated training AUC.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training features.
    Y : array-like of shape (n_samples,)
        Binary training labels.
    base_learners : list of (str, estimator)
        Candidate learners, e.g. the output of ``get_base_learners()``.
    model_name : str, default "SuperLearner"
        Label used in progress and log messages.

    Returns
    -------
    estimator
        The fitted ``StackingClassifier`` (3-fold stacking on ``predict_proba`` with a
        ``LogisticRegressionCV`` meta-learner). If fitting the stack raises, a fitted
        plain ``LogisticRegression`` is returned instead. A failure of the diagnostic
        cross-validation only prints a warning; it never discards the fitted stack.
    """
    print(f"\n << Fitting {model_name} >>")

    # Three tracked steps: build, fit, evaluate.
    pbar = tqdm(total=3, desc=f"<< Training {model_name} >>", leave=False)
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            warnings.filterwarnings("ignore", category=FutureWarning)

            pbar.set_description(f"Building {model_name}")
            meta_learner = LogisticRegressionCV(cv=3, max_iter=5000,
                                                random_state=42, solver='lbfgs')
            sl = StackingClassifier(
                estimators=base_learners,
                cv=3,
                stack_method='predict_proba',
                final_estimator=meta_learner,
                n_jobs=-1
            )
            pbar.update(1)

            pbar.set_description(f"<< Fitting {model_name} >>   ")
            try:
                sl.fit(X, Y)
            except Exception as e:
                # Best-effort fallback so downstream steps still receive a fitted model.
                print(f"Warning: {model_name} fitting failed: {str(e)}")
                print("Using simplified model...")
                from sklearn.linear_model import LogisticRegression
                fallback = LogisticRegression(max_iter=5000, random_state=42)
                fallback.fit(X, Y)
                return fallback
            pbar.update(1)

            # Diagnostic only: cross_val_score refits clones of the stack, so a failure
            # here must not replace the already-fitted model.
            pbar.set_description(f"<< Evaluating {model_name} >>")
            try:
                cv_scores = cross_val_score(sl, X, Y, cv=3, scoring='roc_auc')
            except Exception as e:
                print(f"Warning: {model_name} CV evaluation failed: {str(e)}")
            else:
                print(f"{model_name} Train CV AUC: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
            pbar.update(1)
    finally:
        pbar.close()

    return sl

### `evaluate_calibration` 
This function evaluates the calibration quality of a probabilistic classifier.
It computes the calibration curve, the expected calibration error (the bin-size-weighted mean absolute gap between predicted probability and observed frequency), and the Brier score.
The results indicate whether predicted probabilities are well aligned with observed outcomes.

In [None]:
def evaluate_calibration(y_true, y_prob, n_bins=10):
    """Summarize how well predicted probabilities match observed binary outcomes.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        Observed binary labels (0/1).
    y_prob : array-like of shape (n_samples,)
        Predicted probabilities of the positive class, expected in [0, 1].
    n_bins : int, default 10
        Number of equal-width probability bins.

    Returns
    -------
    dict
        'calibration_error' : expected calibration error, i.e. the bin-size-weighted
            mean of |mean predicted probability - observed positive rate|.
        'brier_score' : ``brier_score_loss(y_true, y_prob)``.
        'fraction_of_positives', 'mean_predicted_value' : output of
            ``calibration_curve(..., strategy='uniform')``.
    """
    # Accept lists / pandas Series: the boolean-mask arithmetic below needs ndarrays.
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob, dtype=float)

    fraction_of_positives, mean_predicted_value = calibration_curve(
        y_true, y_prob, n_bins=n_bins, strategy='uniform'
    )

    # Bin 0 is [0, 1/n_bins] and every other bin is (lower, upper]. The previous strict
    # `y_prob > lower` test put predictions of exactly 0 in no bin, so they were
    # silently dropped from the error.
    inner_edges = np.linspace(0, 1, n_bins + 1)[1:-1]
    bin_ids = np.digitize(y_prob, inner_edges, right=True)

    # sum_b |mean_prob_b - mean_true_b| * count_b / n  ==  sum_b |sum_prob_b - sum_true_b| / n
    # (empty bins contribute 0 on both sides).
    prob_sums = np.bincount(bin_ids, weights=y_prob, minlength=n_bins)
    true_sums = np.bincount(bin_ids, weights=y_true.astype(float), minlength=n_bins)
    calibration_error = float(np.abs(prob_sums - true_sums).sum() / len(y_prob))

    brier_score = brier_score_loss(y_true, y_prob)

    return {
        'calibration_error': calibration_error,
        'brier_score': brier_score,
        'fraction_of_positives': fraction_of_positives,
        'mean_predicted_value': mean_predicted_value
    }

### `calibrate_classifier` 
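This function calibrates a classifier’s probability estimates using scikit-learn’s `CalibratedClassifierCV` with 3-fold cross-validation.
Calibration adjusts the classifier’s predicted probabilities so that they better reflect observed frequencies. Two methods are supported (`method='platt'` or `method='isotonic'`); any other value falls back to Platt scaling.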
- Platt Scaling (Sigmoid)
    - Fits a logistic regression on the classifier’s decision scores.
    - Maps raw scores → probabilities with a sigmoid function.
    - Works well with limited data.
    - Assumes a logistic relationship, which may be too simple in complex cases.
- Isotonic Regression
    - Fits a piecewise non-decreasing function.
    - Very flexible, can adapt to complex calibration patterns.
    - Requires more data → can overfit with small samples.

In [None]:
def calibrate_classifier(base_model, X_train, y_train, method='platt'):
    """Wrap ``base_model`` in a 3-fold ``CalibratedClassifierCV`` and fit it.

    Parameters
    ----------
    base_model : classifier
        Estimator whose probabilities should be calibrated.
    X_train, y_train : array-like
        Data used to fit the calibrated wrapper.
    method : str, default 'platt'
        'platt' maps to sklearn's 'sigmoid' and 'isotonic' to 'isotonic'. Any other
        value prints a warning and falls back to 'sigmoid'.

    Returns
    -------
    CalibratedClassifierCV
        The fitted calibrated model. All warnings raised during fitting are suppressed.
    """
    print(f"   🎯 Calibrating classifier using {method} method...")

    if method in ('platt', 'isotonic'):
        sklearn_method = 'sigmoid' if method == 'platt' else 'isotonic'
    else:
        print(f"   ⚠️  Unknown calibration method: {method}. Using Platt scaling.")
        sklearn_method = 'sigmoid'

    calibrated_model = CalibratedClassifierCV(base_model, method=sklearn_method, cv=3)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        calibrated_model.fit(X_train, y_train)

    return calibrated_model

### `predict_Q_models` 
Compute the Q-model predictions used in causal inference / TMLE:

- Q_A: Predicted outcome using the observed treatment 
- Q_1: Counterfactual prediction if A=1 for everyone.
- Q_0: Counterfactual prediction if A=0 for everyone.

In [None]:
def predict_Q_models(sl, W_A_data, is_test=False):
    """Predict outcome probabilities under observed, all-treated and all-control treatment.

    Parameters
    ----------
    sl : fitted classifier
        Outcome model exposing ``predict_proba``; column 1 is taken as P(Y=1).
    W_A_data : pandas.DataFrame
        Covariates plus a treatment column named ``"A"``. Not modified.
    is_test : bool, default False
        Only selects the wording of the progress message.

    Returns
    -------
    tuple of ndarray
        ``(Q_A, Q_1, Q_0)``: predictions with ``A`` as observed, with ``A`` set to 1 for
        every row, and with ``A`` set to 0 for every row.

    NOTE(review): the counterfactual rows use raw 0/1 in column ``A``; if the model was
    trained on a standardized ``A`` column, these values are on a different scale — confirm
    with the caller.
    """
    print(f"   Predicting Q models on {'test' if is_test else 'train'} set...")

    def _positive_prob(frame):
        return sl.predict_proba(frame)[:, 1]

    def _with_treatment(value):
        # Work on a copy so the caller's frame keeps its observed treatment values.
        frame = W_A_data.copy()
        frame["A"] = value
        return frame

    Q_A = _positive_prob(W_A_data)
    Q_1 = _positive_prob(_with_treatment(1))
    Q_0 = _positive_prob(_with_treatment(0))
    return Q_A, Q_1, Q_0

### `estimate_g_with_calibration` 
This function estimates propensity scores (g-model) with the following pipeline:

- Downsampling to balance treated vs. control in the training set.
- Base g-model training using a SuperLearner ensemble.
- Calibration (Platt scaling or isotonic regression; an unrecognized method falls back to Platt scaling).
- Evaluation on the downsampled training set, the full training set, and the test set.
- Overlap weighting computation for both training and test data.
- Visualization and diagnostics of score distributions.
- Returns the calibrated model, diagnostic metrics, and processed weights.

In [None]:
def _downsample_to_balance(X, A, random_state=42):
    """Downsample the larger treatment arm (without replacement) to the smaller arm's size.

    Returns ``(X_down, A_down)`` as ndarrays with treated rows first, then control rows.
    """
    treated_mask = A == 1
    control_mask = A == 0
    X_treated, X_control = X[treated_mask], X[control_mask]
    A_treated, A_control = A[treated_mask], A[control_mask]

    if len(X_treated) > len(X_control):
        n_keep = len(X_control)
        # Same random_state and size for X and A, so both draw the same rows.
        X_treated = resample(X_treated, replace=False, n_samples=n_keep, random_state=random_state)
        A_treated = resample(A_treated, replace=False, n_samples=n_keep, random_state=random_state)
    else:
        n_keep = len(X_treated)
        X_control = resample(X_control, replace=False, n_samples=n_keep, random_state=random_state)
        A_control = resample(A_control, replace=False, n_samples=n_keep, random_state=random_state)

    return np.vstack([X_treated, X_control]), np.concatenate([A_treated, A_control])


def _overlap_terms(g, A, lower=0.05, upper=0.95):
    """Clip propensity scores to [lower, upper] and derive overlap-weighting quantities.

    Returns a dict with the clipped scores ``g_w``, clever covariates ``H_1 = 1 - g``,
    ``H_0 = g``, ``H_overlap = A(1 - g) - (1 - A)g`` and ``overlap_weights = g(1 - g)``.
    """
    g_trimmed = np.clip(g, lower, upper)
    return {
        'g_w': g_trimmed,
        'H_1': 1 - g_trimmed,
        'H_0': g_trimmed,
        'H_overlap': A * (1 - g_trimmed) - (1 - A) * g_trimmed,
        'overlap_weights': g_trimmed * (1 - g_trimmed),
    }


def _plot_ps_distribution(scores, A, title):
    """Overlay treated/control propensity-score densities with mean and median lines."""
    import seaborn as sns
    plt.figure(figsize=(8, 5))
    sns.histplot(scores[A == 1], bins=30, color='royalblue', label='Treated', kde=True, stat='density', alpha=0.6)
    sns.histplot(scores[A == 0], bins=30, color='orange', label='Control', kde=True, stat='density', alpha=0.6)
    plt.axvline(scores.mean(), color='green', linestyle='--', label=f'Mean: {scores.mean():.2f}')
    plt.axvline(np.median(scores), color='red', linestyle=':', label=f'Median: {np.median(scores):.2f}')
    plt.title(title)
    plt.xlabel('Propensity Score')
    plt.ylabel('Density')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()


def estimate_g_with_calibration(A_train, W_train_standardized, A_test, W_test_standardized, base_learners, calibration_method='platt'):
    """Estimate calibrated propensity scores and overlap-weighting terms.

    Pipeline: balance the training arms by downsampling, fit a SuperLearner g-model,
    calibrate it with ``calibrate_classifier``, report calibration diagnostics, and derive
    clipped scores, overlap weights and clever covariates for both train and test rows.

    Parameters
    ----------
    A_train, A_test : array-like of shape (n,)
        Binary treatment indicators.
    W_train_standardized, W_test_standardized : DataFrame or array-like of shape (n, p)
        Preprocessed covariates (without the treatment column).
    base_learners : list of (str, estimator)
        Candidate learners passed to ``fit_superlearner``.
    calibration_method : str, default 'platt'
        Forwarded to ``calibrate_classifier``.

    Returns
    -------
    dict
        ``'model'`` (calibrated g-model), ``'train'`` / ``'test'`` (each with ``g_w``
        clipped to [0.05, 0.95], ``H_1``, ``H_0``, ``H_overlap``, ``overlap_weights``)
        and ``'calibration_info'`` (metrics and summary statistics).
    """
    print(f"\n << Estimating propensity scores with {calibration_method} calibration >>")

    # Use plain ndarrays throughout so the g-model is fitted and queried with the same
    # input type (no feature-name mismatch when the inputs are DataFrames).
    A_train = np.asarray(A_train)
    A_test = np.asarray(A_test)
    X_train = np.asarray(W_train_standardized)
    X_test = np.asarray(W_test_standardized)

    W_down, A_down = _downsample_to_balance(X_train, A_train)
    print(f"g-model training samples (downsampled): {W_down.shape}, A=1 proportion: {np.mean(A_down):.2f}")

    # NOTE(review): the g-model is fitted and calibrated on a 50/50 downsampled set, so its
    # probabilities reflect a balanced treatment prior rather than the observed prevalence.
    # If population-calibrated propensities are required, an odds (prior-shift) correction
    # would be needed — confirm this is intended.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        base_g_model = fit_superlearner(W_down, A_down, base_learners, "Base Propensity Score Model")

    # Calibration diagnostics before/after, measured in-sample on the downsampled data.
    print("   📊 Evaluating calibration before calibration...")
    pre_cal_metrics = evaluate_calibration(A_down, base_g_model.predict_proba(W_down)[:, 1])
    print(f"   Pre-calibration: CE={pre_cal_metrics['calibration_error']:.4f}, Brier={pre_cal_metrics['brier_score']:.4f}")

    calibrated_g_model = calibrate_classifier(base_g_model, W_down, A_down, calibration_method)

    post_cal_metrics = evaluate_calibration(A_down, calibrated_g_model.predict_proba(W_down)[:, 1])
    print(f"   Post-calibration: CE={post_cal_metrics['calibration_error']:.4f}, Brier={post_cal_metrics['brier_score']:.4f}")

    improvement_ce = pre_cal_metrics['calibration_error'] - post_cal_metrics['calibration_error']
    improvement_brier = pre_cal_metrics['brier_score'] - post_cal_metrics['brier_score']
    print(f"   📈 Improvement: CE Δ={improvement_ce:+.4f}, Brier Δ={improvement_brier:+.4f}")

    print("   📊 Evaluating calibrated model on full training set...")
    g_w_train_full = calibrated_g_model.predict_proba(X_train)[:, 1]
    train_cal_metrics = evaluate_calibration(A_train, g_w_train_full)
    print(f"   Full training set calibration: CE={train_cal_metrics['calibration_error']:.4f}, Brier={train_cal_metrics['brier_score']:.4f}")

    print("   Predicting calibrated propensity scores on test set...")
    g_w_test = calibrated_g_model.predict_proba(X_test)[:, 1]
    test_cal_metrics = evaluate_calibration(A_test, g_w_test)
    print(f"   Test set calibration: CE={test_cal_metrics['calibration_error']:.4f}, Brier={test_cal_metrics['brier_score']:.4f}")

    train_terms = _overlap_terms(g_w_train_full, A_train)
    test_terms = _overlap_terms(g_w_test, A_test)
    g_w_train_trimmed = train_terms['g_w']
    overlap_weights_train = train_terms['overlap_weights']
    g_w_test_trimmed = test_terms['g_w']
    overlap_weights_test = test_terms['overlap_weights']

    # Training-set diagnostics
    print(f"\n📊 Training Set PS Performance:")
    print(f"   Propensity Score: min={g_w_train_full.min():.4f}, max={g_w_train_full.max():.4f}, mean={g_w_train_full.mean():.4f}")
    print(f"   Overlap Weights: min={overlap_weights_train.min():.4f}, max={overlap_weights_train.max():.4f}, mean={overlap_weights_train.mean():.4f}")
    train_good_overlap = np.sum((g_w_train_trimmed >= 0.1) & (g_w_train_trimmed <= 0.9))
    print(f"   Good overlap samples (0.1 ≤ PS ≤ 0.9): {train_good_overlap} ({train_good_overlap/len(g_w_train_trimmed)*100:.1f}%)")
    print(f"   [Average after downsample+overlap+calibration] g_w_train_trimmed mean: {g_w_train_trimmed.mean():.4f}")
    print(f"   [Average after downsample+overlap+calibration] overlap_weights_train mean: {overlap_weights_train.mean():.4f}")
    _plot_ps_distribution(g_w_train_trimmed, A_train, 'Training Set Propensity Score Distribution (g_w_train_trimmed)')

    # Test-set diagnostics
    print(f"\n📊 Test Set PS Performance:")
    print(f"   Propensity Score: min={g_w_test.min():.4f}, max={g_w_test.max():.4f}, mean={g_w_test.mean():.4f}")
    print(f"   Overlap Weights: min={overlap_weights_test.min():.4f}, max={overlap_weights_test.max():.4f}, mean={overlap_weights_test.mean():.4f}")
    test_good_overlap = np.sum((g_w_test_trimmed >= 0.1) & (g_w_test_trimmed <= 0.9))
    print(f"   Good overlap samples (0.1 ≤ PS ≤ 0.9): {test_good_overlap} ({test_good_overlap/len(g_w_test_trimmed)*100:.1f}%)")
    print(f"   [Average after overlap+calibration] g_w_test_trimmed mean: {g_w_test_trimmed.mean():.4f}")
    print(f"   [Average after overlap+calibration] overlap_weights_test mean: {overlap_weights_test.mean():.4f}")
    _plot_ps_distribution(g_w_test_trimmed, A_test, 'Test Set Propensity Score Distribution (g_w_test_trimmed)')

    # Uncalibrated model on the test set, for comparison with the calibrated one.
    base_probs_test = base_g_model.predict_proba(X_test)[:, 1]
    base_test_metrics = evaluate_calibration(A_test, base_probs_test)

    calibration_info = {
        'pre_calibration_metrics': pre_cal_metrics,
        'post_calibration_metrics': post_cal_metrics,
        'train_calibration_metrics': train_cal_metrics,
        'test_calibration_metrics': test_cal_metrics,
        'base_test_metrics': base_test_metrics,
        'calibration_method': calibration_method,
        # Information for training set
        'train_ps_scores': g_w_train_full,
        'train_overlap_weights': overlap_weights_train,
        # Mean values for training and testing set
        'g_w_train_trimmed_mean': g_w_train_trimmed.mean(),
        'overlap_weights_train_mean': overlap_weights_train.mean(),
        'g_w_test_trimmed_mean': g_w_test_trimmed.mean(),
        'overlap_weights_test_mean': overlap_weights_test.mean()
    }

    return {
        'model': calibrated_g_model,
        'test': test_terms,
        'train': train_terms,
        'calibration_info': calibration_info
    }

### `estimate_fluctuation_param`

This function estimates the fluctuation parameter ϵ in Targeted Maximum Likelihood Estimation (TMLE) using Overlap Weighting.

The idea is to slightly fluctuate (update) the initial outcome regression Q_A by solving a logistic regression with an offset.

In [None]:
def estimate_fluctuation_param(Y, Q_A, H_1, H_0, A, H_overlap=None):
    """Estimate the TMLE fluctuation parameter epsilon.

    Fits ``logit(P(Y=1)) = logit(Q_A) + eps * H_A`` as a binomial GLM with offset
    ``logit(Q_A)`` and no intercept.

    Parameters
    ----------
    Y : array-like of shape (n,)
        Observed outcomes.
    Q_A : array-like of shape (n,)
        Initial outcome predictions under the observed treatment; clipped to
        [1e-6, 1 - 1e-6] before the logit.
    H_1, H_0, A : array-like of shape (n,)
        Used to build ``H_A = A * H_1 - (1 - A) * H_0`` when ``H_overlap`` is None.
    H_overlap : array-like, optional
        Precomputed clever covariate; takes precedence over ``H_1``/``H_0``/``A``.

    Returns
    -------
    float
        Estimated epsilon. Falls back to 0.0 (no update) when the GLM fit fails or
        yields a non-finite value.
    """
    print("📈 Estimating fluctuation parameter (Overlap Weighting)...")

    Q_A_clipped = np.clip(np.asarray(Q_A, dtype=float), 1e-6, 1 - 1e-6)
    logit_QA = np.log(Q_A_clipped / (1 - Q_A_clipped))

    # asarray: accept pandas Series / lists, which have no usable .reshape.
    if H_overlap is not None:
        H_A = np.asarray(H_overlap, dtype=float)
    else:
        A = np.asarray(A)
        H_A = A * np.asarray(H_1, dtype=float) - (1 - A) * np.asarray(H_0, dtype=float)
    if H_A.ndim == 1:
        H_A = H_A.reshape(-1, 1)

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            model = sm.GLM(np.asarray(Y), H_A, offset=logit_QA, family=sm.families.Binomial()).fit()
        # Positional access that works whether params is an ndarray or a Series.
        eps = float(np.asarray(model.params)[0])
    except Exception as e:
        print(f"   GLM fitting failed, using fallback method: {str(e)}")
        return 0.0

    if not np.isfinite(eps):
        # e.g. a near-separable fit; a non-finite epsilon would turn every updated Q into NaN.
        print("   Non-finite fluctuation parameter, using fallback method: eps=0.0")
        return 0.0

    print(f"   Fluctuation parameter (epsilon): {eps:.6f}")
    return eps

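### `update_Q`

Updates Q values with the fluctuation parameter on the logit scale, $\operatorname{logit}(Q^*) = \operatorname{logit}(Q) + \epsilon H$, clipping probabilities to $[10^{-6}, 1 - 10^{-6}]$.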

In [None]:
def update_Q(Q_base, H, eps):
    """Apply the TMLE fluctuation to Q on the logit scale.

    Returns ``expit(logit(Q_base) + eps * H)``. ``Q_base`` is clipped to
    [1e-6, 1 - 1e-6] before the logit, and the result is clipped to the same range.
    """
    bound_low, bound_high = 1e-6, 1 - 1e-6
    q = np.clip(Q_base, bound_low, bound_high)
    shifted_logit = np.log(q / (1 - q)) + eps * H
    return np.clip(1 / (1 + np.exp(-shifted_logit)), bound_low, bound_high)

### `compute_tmle`

This function computes **Targeted Maximum Likelihood Estimation (TMLE)** estimates for the **Average Treatment Effect (ATE)** and **Average Treatment effect on the Treated (ATT)**.  
It supports both the traditional TMLE formulation and an **overlap weighting** variant.

- ##### Function Purpose
    - Compute **ATE** and **ATT** using updated outcome regression models (`Q_*_update`).
    - Construct the **influence function (IF)** for variance estimation.
    - Estimate **standard error (SE)**, **95% confidence interval (CI)**, and **p-value**.
    - Handle both **with overlap weighting** and **without overlap weighting** cases.

In [None]:
def compute_tmle(Y, A, Q_A_update, Q_1_update, Q_0_update, H_1, H_0, overlap_weights=None, H_overlap=None):
    """Compute TMLE point estimates, influence function and Wald inference.

    With ``overlap_weights`` w, the "ATE" is the overlap-weighted effect
    ``mean(w * (Q_1 - Q_0)) / mean(w)``. Its influence function is
    ``[H * (Y - Q_A) + w * (tau - ate)] / mean(w)``, where H is the clever covariate.
    Without weights, the plain ATE ``mean(Q_1 - Q_0)`` is used with influence function
    ``H_A * (Y - Q_A) + tau - ate``.

    Parameters
    ----------
    Y, A : array-like of shape (n,)
        Outcomes and binary treatment indicators.
    Q_A_update, Q_1_update, Q_0_update : array-like of shape (n,)
        Targeted outcome predictions under observed treatment, A=1 and A=0.
    H_1, H_0 : array-like of shape (n,)
        Clever-covariate components; the combined covariate is ``A * H_1 - (1 - A) * H_0``.
    overlap_weights : array-like, optional
        Per-row weights; enables the overlap-weighted estimand.
    H_overlap : array-like, optional
        Precomputed clever covariate; used instead of ``H_1``/``H_0`` when weighting.

    Returns
    -------
    tuple
        ``(ate, se, ci_low, ci_high, p_value, infl_fn, att)``. The 95% CI and p-value refer
        to ``ate``. With weights, ``att`` is the overlap-weighted mean of ``Q_1 - Q_0``
        among treated rows, not the classical ATT. ``att`` is NaN if there are no treated rows.
    """
    print("🎯 Computing TMLE estimates (ATE, ATT)...")

    Y = np.asarray(Y, dtype=float)
    A = np.asarray(A)
    Q_A_update = np.asarray(Q_A_update, dtype=float)
    tau = np.asarray(Q_1_update, dtype=float) - np.asarray(Q_0_update, dtype=float)

    if overlap_weights is not None:
        overlap_weights = np.asarray(overlap_weights, dtype=float)
        mean_weight = np.mean(overlap_weights)
        ate = np.mean(tau * overlap_weights) / mean_weight

        treated_idx = (A == 1)
        if np.any(treated_idx):
            w_treated = overlap_weights[treated_idx]
            att_denominator = np.mean(w_treated)
            att = np.mean(tau[treated_idx] * w_treated) / att_denominator if att_denominator > 0 else np.nan
        else:
            att = np.nan

        if H_overlap is None:
            H_overlap = A * np.asarray(H_1) - (1 - A) * np.asarray(H_0)
        # The estimator is a ratio of means, so BOTH the clever-covariate residual term and
        # the weighted-effect term scale by 1/mean(w). The weighted-effect term is already
        # centred at `ate`, so no further constant is subtracted. (Previously the residual
        # term was left unscaled, which understated the SE, and `ate` was subtracted twice.)
        infl_fn = (np.asarray(H_overlap) * (Y - Q_A_update) + overlap_weights * (tau - ate)) / mean_weight

    else:
        # Unweighted ATE & ATT
        ate = np.mean(tau)
        att = np.mean(tau[A == 1]) if np.any(A == 1) else np.nan

        H_A = A * np.asarray(H_1) - (1 - A) * np.asarray(H_0)
        infl_fn = H_A * (Y - Q_A_update) + tau - ate

    # Standard error, 95% CI and two-sided p-value for the ATE
    se = np.sqrt(np.var(infl_fn) / len(Y))
    ci_low = ate - 1.96 * se
    ci_high = ate + 1.96 * se
    p_value = 2 * (1 - norm.cdf(abs(ate / se))) if se > 0 else 1.0

    return ate, se, ci_low, ci_high, p_value, infl_fn, att

### `diagnostic_checks`
This function performs comprehensive diagnostic checks for both the Q-model (outcome regression) and the G-model (propensity score model). It summarizes treatment group balance, predictive performance, calibration quality, and overlap diagnostics for either the training or test dataset.

In [None]:
def _threshold_metrics(y_true, prob):
    """Return ``(auc, precision, recall, f1)`` for probabilities thresholded at 0.5.

    ``roc_auc_score`` raises ValueError when ``y_true`` holds a single class
    (e.g. a split with no events / no treated units); the AUC is then reported
    as NaN instead of aborting the whole analysis. F1 is 0.0 when
    precision + recall is 0, consistent with ``zero_division=0``.
    """
    pred = (prob >= 0.5).astype(int)
    auc = roc_auc_score(y_true, prob) if np.unique(y_true).size > 1 else np.nan
    precision = precision_score(y_true, pred, zero_division=0)
    recall = recall_score(y_true, pred, zero_division=0)
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
    return auc, precision, recall, f1


def diagnostic_checks(Y, A, W, Q_A, Q_1, Q_0, g_w, stage="First", is_test=False, model_performance=None):
    """Print and return diagnostics for the outcome (Q) and propensity (G) models.

    Parameters
    ----------
    Y : array of {0, 1}
        Observed outcome (thresholded Q predictions are scored against it).
    A : array of {0, 1}
        Treatment indicator.
    W : any
        Not used by this function; accepted for signature compatibility.
    Q_A, Q_1, Q_0 : array
        Predicted outcome under the observed treatment, under A=1 and under
        A=0 (the raw ATE below is ``mean(Q_1 - Q_0)``).
    g_w : array
        Propensity scores, scored against ``A``.
    stage : str
        Label used in the printed header and the raw-ATE line.
    is_test : bool
        Selects the "Test"/"Train" label and which calibration metrics are read.
    model_performance : dict, optional
        May hold ``'train_calibration_metrics'`` / ``'test_calibration_metrics'``,
        each with ``'calibration_error'`` and ``'brier_score'``.

    Returns
    -------
    dict
        Treatment proportion, propensity means/overlap, Q and G model metrics
        (AUC is NaN when the labels contain a single class), the raw ATE and
        the count of extreme propensity scores (< 0.05 or > 0.95).
    """
    data_type = "Test" if is_test else "Train"
    print(f"\n=== {stage} {data_type} Diagnostic Checks ===")

    # ========= Treatment group proportion & propensity overlap ==========
    treatment_prop = A.mean()
    print(f"{data_type} Treatment group proportion: {treatment_prop:.4f}")

    g_treated = g_w[A == 1].mean()
    g_control = g_w[A == 0].mean()
    # Mean of min(g, 1 - g): 0.5 is the maximum (g = 0.5 everywhere), 0 means deterministic treatment.
    g_overlap = np.minimum(g_w, 1 - g_w).mean()

    print(f"{data_type} Propensity Score: Treated Mean: {g_treated:.4f}")
    print(f"{data_type} Propensity Score: Control Mean: {g_control:.4f}")
    print(f"{data_type} Overlap Measure: {g_overlap:.4f} (Higher is better)")

    # ========= Q-model performance metrics ==========
    auc, precision, recall, f1 = _threshold_metrics(Y, Q_A)
    print(f"{data_type} Q model AUC: {auc:.4f}")
    print(f"{data_type} Q model Precision: {precision:.4f}")
    print(f"{data_type} Q model Recall: {recall:.4f}")
    if precision + recall > 0:
        print(f"{data_type} Q model F1 Score: {f1:.4f}")

    # ========== G-model (propensity score) performance metrics ==========
    print("\n--- G-model Performance ---")
    g_auc, g_precision, g_recall, g_f1 = _threshold_metrics(A, g_w)
    g_brier = brier_score_loss(A, g_w)

    print(f"{data_type} G model AUC: {g_auc:.4f}")
    print(f"{data_type} G model Precision: {g_precision:.4f}")
    print(f"{data_type} G model Recall: {g_recall:.4f}")
    if g_precision + g_recall > 0:
        print(f"{data_type} G model F1 Score: {g_f1:.4f}")
    print(f"{data_type} G model Brier Score: {g_brier:.4f}")

    # ========== G-model Calibration Metrics ==========
    print("\n--- Calibration Metrics ---")
    if model_performance is not None:
        cal_key = 'test_calibration_metrics' if is_test else 'train_calibration_metrics'
        if cal_key in model_performance:
            cal_metrics = model_performance[cal_key]
            print(f"{data_type} G model Calibration Error: {cal_metrics['calibration_error']:.4f}")
            print(f"{data_type} G model Brier Score: {cal_metrics['brier_score']:.4f}")

    # ========= Q-model Distribution ==========
    print(f"{data_type} Q_A distribution: min={Q_A.min():.4f}, max={Q_A.max():.4f}, mean={Q_A.mean():.4f}")
    print(f"{data_type} Q_1 distribution: min={Q_1.min():.4f}, max={Q_1.max():.4f}, mean={Q_1.mean():.4f}")
    print(f"{data_type} Q_0 distribution: min={Q_0.min():.4f}, max={Q_0.max():.4f}, mean={Q_0.mean():.4f}")

    # ========= Raw (plug-in) ATE Estimate ==========
    raw_ate = np.mean(Q_1 - Q_0)
    print(f"{data_type} {stage} Raw ATE estimate: {raw_ate:.6f}")

    # ========= Extreme Propensity Scores ==========
    extreme_ps = np.sum((g_w < 0.05) | (g_w > 0.95))
    print(f"{data_type} Extreme propensity score samples: {extreme_ps} ({extreme_ps/len(g_w)*100:.2f}%)")

    return {
        # ====== Treatment group proportion and overlap metrics ======
        'treatment_prop': treatment_prop,
        'g_treated': g_treated,
        'g_control': g_control,
        'g_overlap': g_overlap,
        # ====== Q-model performance metrics ======
        'q_auc': auc,
        'q_precision': precision,
        'q_recall': recall,
        'q_f1': f1,
        'raw_ate': raw_ate,
        'extreme_ps_count': extreme_ps,
        # ====== G-model performance metrics ======
        'g_auc': g_auc,
        'g_precision': g_precision,
        'g_recall': g_recall,
        'g_f1': g_f1,
        'g_brier': g_brier,
    }

### `print_train_test_comparison`
This function prints a comprehensive comparison of train vs. test performance for both the Q-model (outcome model) and the G-model (propensity score model), along with overlap/propensity-score quality checks and an overfitting assessment.

- Arguments
    - train_metrics: dictionary of performance metrics computed on the training set.
    - test_metrics: dictionary of performance metrics computed on the test set.
    - title: (optional) custom title for the printed report.

- Output
    - Prints a formatted table of metrics for Q-model and G-model.
    - Provides automatic overfitting assessments based on thresholds.
    - Prints calibration evaluation if calibration metrics are provided.

In [None]:
def print_train_test_comparison(train_metrics, test_metrics, title="Model Performance Comparison"):
    """Print a train-vs-test report for the Q- and G-models and flag overfitting.

    Parameters
    ----------
    train_metrics, test_metrics : dict
        Metric dictionaries such as those returned by ``diagnostic_checks``.
        Any missing metric key is read as 0. If ``train_metrics`` holds
        ``'train_calibration_metrics'`` and ``test_metrics`` holds
        ``'test_calibration_metrics'`` (each with ``'calibration_error'``),
        a calibration-degradation section is printed as well.
    title : str
        Heading printed at the top of the report.

    Notes
    -----
    The "Difference" column is train minus test, except for the Brier score,
    where it is test minus train because lower Brier is better. A positive
    value therefore always means "worse on test". Per-row flags trigger
    above 0.05 (0.02 for Brier). The summary grades use 0.1 / 0.05 on the
    AUC and F1 drops and 0.05 / 0.02 on the Brier rise.
    """
    banner = "=" * 80

    def both(key):
        # Missing keys default to 0 so a partial metrics dict never raises.
        return train_metrics.get(key, 0), test_metrics.get(key, 0)

    def metric_table(heading, rows):
        print(heading)
        print(f"{'Metric':<20} {'Train':<12} {'Test':<12} {'Difference':<12} {'Overfitting?':<12}")
        print("-" * 80)
        for key, label, lower_is_better, limit in rows:
            tr, te = both(key)
            gap = te - tr if lower_is_better else tr - te
            flag = "Yes" if gap > limit else "No"
            print(f"{label:<20} {tr:<12.4f} {te:<12.4f} {gap:+12.4f} {flag:<12}")

    print("\n" + banner)
    print(f"                    {title}")
    print(banner)

    metric_table("Q-MODEL (Outcome Model) Performance:", [
        ('q_auc', 'AUC', False, 0.05),
        ('q_precision', 'Precision', False, 0.05),
        ('q_recall', 'Recall', False, 0.05),
        ('q_f1', 'F1 Score', False, 0.05),
    ])

    metric_table("\n G-MODEL (Propensity Score Model) Performance:", [
        ('g_auc', 'AUC', False, 0.05),
        ('g_precision', 'Precision', False, 0.05),
        ('g_recall', 'Recall', False, 0.05),
        ('g_f1', 'F1 Score', False, 0.05),
        ('g_brier', 'Brier Score', True, 0.02),
    ])

    print("\n Overlap & PS Quality: ")
    print(f"{'Metric':<20} {'Train':<12} {'Test':<12} {'Difference':<12}")
    print("-" * 65)
    for key, label, spec in (('g_overlap', 'Overlap Quality', '.4f'),
                             ('extreme_ps_count', 'Extreme PS Count', '.0f')):
        tr, te = both(key)
        print(f"{label:<20} {tr:<12{spec}} {te:<12{spec}} {tr - te:+12{spec}}")

    # ========= Overfitting Assessment ==========
    print("\n" + banner)
    print("                         OVERFITTING ASSESSMENT")
    print(banner)

    tr_q_auc, te_q_auc = both('q_auc')
    tr_q_f1, te_q_f1 = both('q_f1')
    q_auc_drop = tr_q_auc - te_q_auc
    q_f1_drop = tr_q_f1 - te_q_f1

    if q_auc_drop > 0.1 or q_f1_drop > 0.1:
        q_verdict = "🔴 Significant Q-model overfitting detected"
    elif q_auc_drop > 0.05 or q_f1_drop > 0.05:
        q_verdict = "🟡 Moderate Q-model overfitting detected"
    else:
        q_verdict = "🟢 Q-model shows good generalization"

    tr_g_auc, te_g_auc = both('g_auc')
    tr_g_f1, te_g_f1 = both('g_f1')
    tr_g_brier, te_g_brier = both('g_brier')
    g_auc_drop = tr_g_auc - te_g_auc
    g_f1_drop = tr_g_f1 - te_g_f1
    g_brier_rise = te_g_brier - tr_g_brier

    if g_auc_drop > 0.1 or g_f1_drop > 0.1 or g_brier_rise > 0.05:
        g_verdict = "🔴 Significant G-model overfitting detected"
    elif g_auc_drop > 0.05 or g_f1_drop > 0.05 or g_brier_rise > 0.02:
        g_verdict = "🟡 Moderate G-model overfitting detected"
    else:
        g_verdict = "🟢 G-model shows good generalization"

    print(q_verdict)
    print(g_verdict)

    # ========= Calibration Assessment (only when both sides carry it) ==========
    has_calibration = ('train_calibration_metrics' in train_metrics
                       and 'test_calibration_metrics' in test_metrics)
    if not has_calibration:
        return

    train_ce = train_metrics['train_calibration_metrics']['calibration_error']
    test_ce = test_metrics['test_calibration_metrics']['calibration_error']
    ce_drift = test_ce - train_ce

    if ce_drift > 0.05:
        print("🔴 Significant calibration degradation on test set")
    elif ce_drift > 0.02:
        print("🟡 Moderate calibration degradation on test set")
    else:
        print("🟢 Calibration maintains well on test set")
    print(f"   Train Calibration Error: {train_ce:.4f}")
    print(f"   Test Calibration Error: {test_ce:.4f}")
    print(f"   Degradation: {ce_drift:+.4f}")

### `print_results`
This function prints a comprehensive summary of TMLE (Targeted Maximum Likelihood Estimation) results, integrating effect estimates, calibration assessment, train/test performance comparisons, and interpretation of findings.

In [None]:
def _describe_effect(ate):
    """Return ``(effect_size, direction)`` labels for the one-line effect summary.

    Magnitude bands on ``|ate|``: < 0.01 "Negligible", < 0.05 "Small",
    < 0.1 "Medium", otherwise "Large". The direction is "Positive" or
    "Negative" by sign, and "Null" for an exactly-zero estimate. A plain
    ``ate > 0`` test would label zero as "Negative". A NaN estimate yields
    ``("Undetermined", "")`` instead of falling through to "Large Negative".
    """
    if np.isnan(ate):
        return "Undetermined", ""
    magnitude = abs(ate)
    if magnitude < 0.01:
        effect_size = "Negligible"
    elif magnitude < 0.05:
        effect_size = "Small"
    elif magnitude < 0.1:
        effect_size = "Medium"
    else:
        effect_size = "Large"
    if ate > 0:
        direction = "Positive"
    elif ate < 0:
        direction = "Negative"
    else:
        direction = "Null"
    return effect_size, direction


def print_results(ate, se, ci_low, ci_high, p_value, raw_ate, train_diagnostics, test_pre_diagnostics, test_post_diagnostics, att, calibration_info=None):
    """Print the final TMLE report.

    The report covers estimates, calibration, the train/test comparison and
    interpretation.

    Parameters
    ----------
    ate, se, ci_low, ci_high, p_value : float
        ATE point estimate, standard error, 95% CI bounds and two-sided p-value.
    raw_ate : float
        Pre-targeting ATE. Not read here: the report uses
        ``test_pre_diagnostics['raw_ate']``, which the pipeline passes with the
        same value. The parameter is kept for signature compatibility.
    train_diagnostics, test_pre_diagnostics, test_post_diagnostics : dict
        Outputs of ``diagnostic_checks``. NOTE: ``train_diagnostics`` and
        ``test_post_diagnostics`` are updated in place with the train/test
        calibration metrics when ``calibration_info`` provides them.
    att : float
        ATT point estimate (NaN when unavailable). No SE/CI is reported for it.
    calibration_info : dict, optional
        Must contain ``'calibration_method'``, ``'pre_calibration_metrics'`` and
        ``'post_calibration_metrics'`` (each with ``'calibration_error'`` and
        ``'brier_score'``). ``'train_calibration_metrics'`` and
        ``'test_calibration_metrics'`` are optional; the sections that need
        them are skipped when they are absent, instead of raising KeyError.
    """
    rule = "=" * 80
    thin_rule = "-" * 60

    print("\n" + rule)
    print("    TMLE Results (Overlap Weighting + Calibrated G-Model + Train/Test Split)")
    print(rule)

    # ======== main results table ========
    print(f"{'Estimand':<12} {'Estimate':<12} {'Std.Err':<10} {'95% CI':<25} {'P-value':<10} {'Significant':<12}")
    print("-" * 80)
    # Pad the CI text to its 25-char header so the P-value column lines up.
    ci_text = f"[{ci_low:.6f}, {ci_high:.6f}]"
    significant = 'Yes' if p_value < 0.05 else 'No'
    print(f"{'ATE':<12} {ate:<12.6f} {se:<10.6f} {ci_text:<25} {p_value:<10.6f} {significant:<12}")

    if not np.isnan(att):
        print(f"{'ATT':<12} {att:<12.6f} {'---':<10} {'---':<25} {'---':<10} {'---':<12}")
    else:
        print(f"{'ATT':<12} {'N/A':<12} {'---':<10} {'---':<25} {'---':<10} {'---':<12}")

    # ======== Calibration information ========
    if calibration_info is not None:
        print("\n" + rule)
        print("                      CALIBRATION ASSESSMENT")
        print(rule)
        print(f"📊 Calibration Method: {calibration_info['calibration_method'].title()}")
        print("\n" + thin_rule)
        print("            Calibration Metrics Comparison (Downsampled Training)")
        print(thin_rule)
        print(f"{'Metric':<20} {'Before':<12} {'After':<12} {'Improvement':<12}")
        print(thin_rule)

        pre_ce = calibration_info['pre_calibration_metrics']['calibration_error']
        post_ce = calibration_info['post_calibration_metrics']['calibration_error']
        ce_improvement = pre_ce - post_ce

        pre_brier = calibration_info['pre_calibration_metrics']['brier_score']
        post_brier = calibration_info['post_calibration_metrics']['brier_score']
        brier_improvement = pre_brier - post_brier

        print(f"{'Calibration Error':<20} {pre_ce:<12.4f} {post_ce:<12.4f} {ce_improvement:+.4f}")
        print(f"{'Brier Score':<20} {pre_brier:<12.4f} {post_brier:<12.4f} {brier_improvement:+.4f}")

        has_train_cal = 'train_calibration_metrics' in calibration_info
        has_test_cal = 'test_calibration_metrics' in calibration_info

        # ======== Complete training and testing set calibration performance comparison ========
        if has_train_cal and has_test_cal:
            print("\n" + thin_rule)
            print("           Full Train vs Test Set Calibration Performance")
            print(thin_rule)
            train_ce = calibration_info['train_calibration_metrics']['calibration_error']
            train_brier = calibration_info['train_calibration_metrics']['brier_score']
            test_ce = calibration_info['test_calibration_metrics']['calibration_error']
            test_brier = calibration_info['test_calibration_metrics']['brier_score']

            print(f"{'Train Set CE':<20} {train_ce:<12.4f}")
            print(f"{'Test Set CE':<20} {test_ce:<12.4f}")
            print(f"{'CE Degradation':<20} {test_ce - train_ce:+12.4f}")
            print(f"{'Train Set Brier':<20} {train_brier:<12.4f}")
            print(f"{'Test Set Brier':<20} {test_brier:<12.4f}")
            print(f"{'Brier Degradation':<20} {test_brier - train_brier:+12.4f}")

        # Attach calibration metrics so print_train_test_comparison can report calibration drift.
        if has_train_cal:
            train_diagnostics.update({
                'train_calibration_metrics': calibration_info['train_calibration_metrics']
            })
        if has_test_cal:
            test_post_diagnostics.update({
                'test_calibration_metrics': calibration_info['test_calibration_metrics']
            })

        print("\n📈 Calibration Assessment:")
        if ce_improvement > 0.01:
            print("   ✅ Significant improvement in training calibration error")
        elif ce_improvement > 0:
            print("   ✓ Slight improvement in training calibration error")
        else:
            print("   ⚠️  No improvement in training calibration error")
        # ======== Test Set Calibration Assessment ========
        if has_test_cal:
            test_ce = calibration_info['test_calibration_metrics']['calibration_error']
            if test_ce < 0.05:
                print("   ✅ Excellent test set calibration (CE < 0.05)")
            elif test_ce < 0.10:
                print("   ✓ Good test set calibration (CE < 0.10)")
            else:
                print("   ⚠️  Poor test set calibration (CE ≥ 0.10)")

    # ======== Comprehensive train/test performance comparison ========
    print_train_test_comparison(train_diagnostics, test_post_diagnostics, "COMPREHENSIVE TRAIN/TEST PERFORMANCE COMPARISON")

    print("\n" + rule)
    print("                          INTERPRETATIONS")
    print(rule)
    # ======== ATE interpretations ========
    print("🎯 ATE (Average Treatment Effect):")
    print("   - Population-wide average causal effect")
    print("   - Uses overlap weighting and calibrated propensity scores")
    print(f"   - Estimate: {ate:.6f}")
    # ======== ATT interpretations ========
    print("\n🎪 ATT (Average Treatment Effect on the Treated):")
    print("   - Average causal effect among those who received treatment")
    print("   - Policy-relevant for understanding treatment effectiveness")
    if not np.isnan(att):
        print(f"   - Estimate: {att:.6f}")
    else:
        # ATT is NaN either when there are no treated units or when the treated
        # units carry zero total overlap weight.
        print("   - Not available (no treated units or zero overlap weight among treated)")

    # ======== Raw vs TMLE Comparison ========
    pre_raw_ate = test_pre_diagnostics['raw_ate']
    adjustment = abs(ate - pre_raw_ate)
    print("\n" + thin_rule)
    print("            Raw vs TMLE Comparison (Test Set)")
    print(thin_rule)
    print(f"Raw ATE (Pre-update):     {pre_raw_ate:.6f}")
    print(f"TMLE ATE (Post-update):   {ate:.6f}")
    print(f"Adjustment Magnitude:     {adjustment:.6f}")
    if pre_raw_ate != 0:
        relative_change = adjustment / abs(pre_raw_ate) * 100
        print(f"Relative Change:          {relative_change:.2f}%")

    print("\n" + rule)

    # ======== Effect Size Interpretation ========
    effect_size, direction = _describe_effect(ate)
    effect_label = " ".join(part for part in (effect_size, direction) if part)
    significance = "Statistically Significant" if p_value < 0.05 else "Not Statistically Significant"

    print(f"📊 Effect Summary: {effect_label} Treatment Effect, {significance}")
    print("   🎯 Enhanced with calibrated propensity scores for improved reliability")
    print("   📈 Comprehensive train/test performance comparison provided")
    print("   🔍 Overfitting assessment included for model validation")
    print(rule)


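### `tmle_project`
This function runs the full TMLE pipeline end to end:
- It loads and splits the data and standardizes the covariates.
- It fits the outcome (Q) model on the training set and estimates calibrated propensity scores (g).
- It performs the targeting update and the final ATE/ATT estimation on the test set.
- It prints diagnostics and results.

It returns a dictionary of estimates and diagnostics, or `None` if any step fails.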

In [None]:
def tmle_project(file_path, test_size=0.3, random_state=42, calibration_method='platt'):
    """Run the complete train/test TMLE analysis on a CSV file and print a report.

    The outcome (Q) model is fitted on the training split only. The targeting
    update, the final ATE/ATT estimates and inference use the held-out test
    split. A progress bar advances once per step (12 steps).

    Parameters
    ----------
    file_path : str
        CSV read by ``data_loading`` (must contain columns ``Y`` and ``A``).
    test_size : float
        Fraction of rows held out for the test split.
    random_state : int
        Seed forwarded to ``data_loading`` for the split.
    calibration_method : str
        Propensity calibration method forwarded to
        ``estimate_g_with_calibration`` (the main cell lists 'platt' and 'isotonic').

    Returns
    -------
    dict or None
        Estimates, diagnostics and run metadata, or ``None`` if any step
        raised. In that case the error message and full traceback are printed.
    """
    import traceback  # only needed to report failures

    print("***** Start TMLE Analysis with Calibrated G-Model & Train/Test Comparison *****")
    print("="*80)
    print(f" Using {calibration_method} calibration method")

    with tqdm(total=12, desc="TMLE Progress", bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as main_pbar:

        try:
            # 1. Data loading and train/test splitting
            main_pbar.set_description("====== Loading and Splitting Data ======")
            data_splits = data_loading(file_path, test_size, random_state)
            train, test = data_splits['train'], data_splits['test']
            main_pbar.update(1)

            # 2. Standardize covariates, then append the treatment column
            main_pbar.set_description("======  Preprocessing Data =======")
            W_train_std, W_test_std, _scaler = data_preprocessing(train['W'], test['W'])

            # pd.concat(axis=1) aligns rows by index label. Build the A column on the
            # covariate frame's own index so rows stay paired even if W_*_std keeps
            # the (non-contiguous) indices of the split.
            W_A_train_std = pd.concat(
                [W_train_std, pd.DataFrame(train['A'], columns=['A'], index=W_train_std.index)],
                axis=1,
            )
            W_A_test_std = pd.concat(
                [W_test_std, pd.DataFrame(test['A'], columns=['A'], index=W_test_std.index)],
                axis=1,
            )
            main_pbar.update(1)

            # 3. Set up base learners
            main_pbar.set_description("====== Setting Up Base Learners ======")
            base_learners = get_base_learners(W_train_std.shape[1])
            main_pbar.update(1)

            # 4. Fit Q models on the training set, predict on both splits
            main_pbar.set_description("======= Step 1: Fit Outcome Models (Q) on Train Set ======")
            sl = fit_superlearner(W_A_train_std, train['Y'], base_learners, "Outcome Model")
            Q_A_train, Q_1_train, Q_0_train = predict_Q_models(sl, W_A_train_std, is_test=False)
            Q_A_test, Q_1_test, Q_0_test = predict_Q_models(sl, W_A_test_std, is_test=True)
            main_pbar.update(1)

            # 5. Estimate calibrated propensity scores (g) on both splits
            main_pbar.set_description("====== Step 2: Estimate Calibrated Propensity Scores (g) ======")
            g_results = estimate_g_with_calibration(
                train['A'], W_train_std,
                test['A'], W_test_std,
                base_learners, calibration_method
            )
            g_test = g_results['test']
            calibration_info = g_results['calibration_info']
            main_pbar.update(1)

            # 6. Training set diagnostic checks
            main_pbar.set_description("====== Diagnostic Checks : Training Set ======")
            train_diagnostics = diagnostic_checks(
                train['Y'], train['A'], train['W'],
                Q_A_train, Q_1_train, Q_0_train, g_results['train']['g_w'], "Training", is_test=False,
                model_performance=calibration_info
            )
            main_pbar.update(1)

            # 7. Pre-update diagnostic checks (test set)
            main_pbar.set_description("====== Diagnostic Checks : Pre-update (Test Set) ======")
            test_pre_diagnostics = diagnostic_checks(
                test['Y'], test['A'], test['W'],
                Q_A_test, Q_1_test, Q_0_test, g_test['g_w'], "Pre-update", is_test=True,
                model_performance=calibration_info
            )
            main_pbar.update(1)

            # 8. Estimate the fluctuation parameter and update Q (test set)
            main_pbar.set_description("====== Step 3: TMLE Update (Test Set) ======")
            eps = estimate_fluctuation_param(
                test['Y'], Q_A_test, g_test['H_1'], g_test['H_0'],
                test['A'], g_test['H_overlap']
            )
            # Counterfactual updates use the clever covariate evaluated at A=1 (1 - g)
            # and at A=0 (-g). NOTE(review): assumes H_overlap = A(1-g) - (1-A)g — confirm.
            Q_A_update_test = update_Q(Q_A_test, g_test['H_overlap'], eps)
            Q_1_update_test = update_Q(Q_1_test, (1 - g_test['g_w']), eps)
            Q_0_update_test = update_Q(Q_0_test, (-g_test['g_w']), eps)
            main_pbar.update(1)

            # 9. Post-update diagnostic checks (test set)
            main_pbar.set_description("====== Diagnostic Checks : Post-update (Test Set) ======")
            test_post_diagnostics = diagnostic_checks(
                test['Y'], test['A'], test['W'],
                Q_A_update_test, Q_1_update_test, Q_0_update_test, g_test['g_w'], "Post-update", is_test=True,
                model_performance=calibration_info
            )
            main_pbar.update(1)

            # 10. Compute final estimates and inference (test set)
            main_pbar.set_description("====== Compute Final Results (Test Set) ======")
            ate, se, ci_low, ci_high, p_value, infl_fn, att = compute_tmle(
                test['Y'], test['A'],
                Q_A_update_test, Q_1_update_test, Q_0_update_test,
                g_test['H_1'], g_test['H_0'],
                g_test['overlap_weights'], g_test['H_overlap']
            )
            raw_ate = test_pre_diagnostics['raw_ate']
            main_pbar.update(1)

            # 11. Print results
            main_pbar.set_description("====== Printing Results ======")
            print_results(ate, se, ci_low, ci_high, p_value, raw_ate, train_diagnostics, test_pre_diagnostics, test_post_diagnostics, att, calibration_info)
            main_pbar.update(1)

            # 12. Dataset summary
            main_pbar.set_description("====== Final Summary ======")
            train_size = len(train['Y'])
            test_n = len(test['Y'])
            print(f"\n📊 Dataset Information:")
            print(f"   Training set size: {train_size}")
            print(f"   Test set size: {test_n}")
            print(f"   Test size ratio: {test_size*100:.1f}%")
            print(f"   Calibration method: {calibration_method}")
            main_pbar.update(1)

            return {
                'ate': ate,
                'att': att,
                'se': se,
                'ci_low': ci_low,
                'ci_high': ci_high,
                'p_value': p_value,
                'raw_ate': raw_ate,
                'influence_function': infl_fn,
                'train_diagnostics': train_diagnostics,
                'test_pre_diagnostics': test_pre_diagnostics,
                'test_post_diagnostics': test_post_diagnostics,
                'overlap_weights': g_test['overlap_weights'],
                'g_scores': g_test['g_w'],
                'calibration_info': calibration_info,
                'train_size': train_size,
                'test_size': test_n,
                'test_ratio': test_size,
                'calibration_method': calibration_method
            }

        except Exception as e:
            # Top-level boundary of the pipeline: report and return None, but keep the
            # traceback so the failing step can actually be located.
            main_pbar.set_description("<< Analysis Failed >>")
            print(f"An error occurred during the analysis: {str(e)}")
            print("Please check the data format and path")
            traceback.print_exc()
            return None

### main 

In [None]:
if __name__ == "__main__":
    import os

    # calibration method: 'platt' or 'isotonic'
    calibration_methods = ['platt', 'isotonic']

    # The data path and calibration method can be overridden with environment
    # variables, so the cell runs on machines other than the original author's.
    # The defaults reproduce the original run (Platt scaling on the original CSV).
    data_path = os.environ.get(
        "TMLE_DATA_PATH",
        '/Users/chendawei/Desktop/MIT TMLE ICU project/yasmeen tmle/tmle_data.csv',
    )
    calibration_method = os.environ.get("TMLE_CALIBRATION_METHOD", 'platt')
    if calibration_method not in calibration_methods:
        raise ValueError(
            f"Unknown calibration method {calibration_method!r}; expected one of {calibration_methods}"
        )

    print("🎯 Available calibration methods:")
    print("   - 'platt': Platt scaling (sigmoid function)")
    print("   - 'isotonic': Isotonic regression (monotonic function)")
    print("\n" + "="*60)

    results = tmle_project(
        data_path,
        test_size=0.3,
        random_state=42,
        calibration_method=calibration_method,
    )