In [None]:
import os

# Move to the project root so the relative 'Variables/' and 'Models/' paths
# used by later cells resolve. The location can be overridden through the
# KD_PROJECT_DIR environment variable; the original cluster path is the fallback,
# so behavior is unchanged when the variable is not set.
PROJECT_DIR = os.environ.get(
    'KD_PROJECT_DIR',
    '/rds/general/user/sw3720/home/codes/Python/Knowledge_Distillation'
)
os.chdir(PROJECT_DIR)
print(f"Current working directory: {os.getcwd()}")

## Import Libraries

In [9]:
import joblib
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_curve, auc,
    precision_recall_curve, average_precision_score,
    accuracy_score, recall_score, confusion_matrix
)
from sklearn.model_selection import KFold
from sklearn.calibration import CalibratedClassifierCV
from tqdm import tqdm
import xgboost as xgb

## Parameters and Model Configuration

In [4]:
# Baseline XGBoost hyper-parameters; this dict is the default `params`
# argument of train_xgb_model and cross_validate_xgb.
XGB_DEFAULT_PARAMS = dict(
    objective='binary:logistic',
    eval_metric='logloss',
    booster='gbtree',
    reg_alpha=0.05,
    reg_lambda=0.05,
    max_depth=8,
    learning_rate=0.01,
    subsample=0.8,
)
CV_FOLDS = 5                    # default number of folds in cross_validate_xgb
BOOTSTRAP_ITERATIONS = 1000     # default number of bootstrap resamples
EARLY_STOPPING_NORMAL = 1000    # early-stopping rounds for standard training
EARLY_STOPPING_KD = 500         # early-stopping rounds when distillation is enabled
DEFAULT_THRESHOLD = 0.5         # default score cut-off for hard class predictions

## Utility Functions

In [27]:
def get_specificity(y_true, y_pred):
    """
    Compute specificity (true-negative rate): TN / (TN + FP).

    y_true, y_pred: binary labels encoded as 0 (negative) and 1 (positive).
    Returns np.nan when y_true contains no negative samples.
    """
    # Pin the label set so the matrix is always 2x2. Without `labels`, inputs
    # in which only one class appears produce a 1x1 matrix and the four-way
    # unpacking below raises ValueError.
    tn, fp, _fn, _tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    negatives = tn + fp
    return tn / negatives if negatives > 0 else np.nan


def bootstrap_confidence_interval(y_true, y_scores, metric='roc', n_iters=BOOTSTRAP_ITERATIONS, alpha=0.05, random_state=42):
    """
    Compute a percentile-bootstrap confidence interval for ROC AUC or AUPRC.

    y_true: binary ground-truth labels (array-like).
    y_scores: predicted scores aligned with y_true (array-like).
    metric: 'roc' for ROC AUC; any other value selects the area under the
        precision-recall curve (trapezoidal, via auc(recall, precision)).
    n_iters: number of bootstrap resamples drawn.
    alpha: significance level; the interval spans the alpha/2 and
        1 - alpha/2 percentiles of the resampled scores.
    random_state: seed for the resampling RNG.

    Returns (lower_ci, upper_ci). Both are np.nan when no usable resample
    exists (empty input, or every resample contained a single class).
    """
    # Positional indexing is required below; a pandas Series would otherwise
    # be indexed by label when given an integer array.
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)

    n = len(y_true)
    if n == 0:
        return np.nan, np.nan

    rng = np.random.RandomState(random_state)
    scores = []

    for _ in range(n_iters):
        idx = rng.randint(0, n, n)
        y_boot, s_boot = y_true[idx], y_scores[idx]
        # Both curves are undefined when the resample holds a single class.
        if len(np.unique(y_boot)) < 2:
            continue
        if metric == 'roc':
            fpr, tpr, _ = roc_curve(y_boot, s_boot)
            scores.append(auc(fpr, tpr))
        else:
            prec, rec, _ = precision_recall_curve(y_boot, s_boot)
            scores.append(auc(rec, prec))

    # Every resample was skipped: there is nothing to take percentiles of.
    if not scores:
        return np.nan, np.nan

    lower = np.percentile(scores, 100 * (alpha / 2))
    upper = np.percentile(scores, 100 * (1 - alpha / 2))
    return lower, upper


def compute_classification_metrics(y_true, y_scores, threshold=DEFAULT_THRESHOLD):
    """
    Score predicted probabilities against binary ground truth.

    Hard predictions are derived as `y_scores > threshold`.

    Returns a tuple (metrics, (fpr, tpr)) where `metrics` is a dict with keys:
      'precision'   - average precision (average_precision_score on the raw
                      scores), NOT precision at the threshold
      'accuracy'    - accuracy of the thresholded predictions
      'recall'      - recall of the thresholded predictions
      'specificity' - TN / (TN + FP) of the thresholded predictions
      'roc_auc'     - area under the ROC curve
      'auprc'       - trapezoidal area under the precision-recall curve
    and (fpr, tpr) are the ROC curve coordinates.
    """
    hard_preds = (y_scores > threshold).astype(int)

    # Threshold-free curves computed from the raw scores
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_scores)

    metrics = dict(
        precision=average_precision_score(y_true, y_scores),
        accuracy=accuracy_score(y_true, hard_preds),
        recall=recall_score(y_true, hard_preds),
        specificity=get_specificity(y_true, hard_preds),
        roc_auc=auc(fpr, tpr),
        auprc=auc(pr_recall, pr_precision),
    )
    return metrics, (fpr, tpr)

# ------------------------------------
# Custom Loss for Knowledge Distillation
# ------------------------------------
def distillation_gradient_hessian(preds, dtrain, alpha=0.4, temperature=1.2):
    """
    Custom objective: combined gradient/hessian for knowledge distillation.

    Per-sample loss, with p = sigmoid(pred) and q_T the temperature-softened
    teacher probability:

        alpha * CE(label, p) + (1 - alpha) * temperature**2 * CE(q_T, p)

    CE(q_T, p) differs from KL(q_T || p) only by a term that does not depend
    on the student, so both have the same derivatives.

    preds: student scores, treated as raw margins (a sigmoid is applied here).
    dtrain: object exposing get_label() (hard 0/1 labels) and get_weight();
        the weight slot carries the teacher probabilities.
    alpha: weight of the hard-label term; (1 - alpha) weights the teacher term.
    temperature: divisor applied to the teacher logits before re-squashing.

    Returns (grad, hess), the first and second derivatives of the loss with
    respect to the raw margin, one value per sample.
    Raises ValueError if the teacher probabilities do not match the labels in shape.
    """
    labels = np.asarray(dtrain.get_label(), dtype=np.float64)
    teacher_probs = np.asarray(dtrain.get_weight(), dtype=np.float64)
    if teacher_probs.shape != labels.shape:
        raise ValueError(
            "dtrain weights must hold one teacher probability per sample "
            f"(got shape {teacher_probs.shape}, expected {labels.shape})"
        )

    # Keep the teacher logits finite: a probability of exactly 0 or 1 would
    # otherwise give log(0) or a division by zero.
    eps = 1e-7
    teacher_probs = np.clip(teacher_probs, eps, 1 - eps)

    # Student probabilities
    preds_prob = 1 / (1 + np.exp(-preds))
    # Teacher probabilities (scaled)
    teacher_logits = np.log(teacher_probs / (1 - teacher_probs))
    scaled_teacher_logits = teacher_logits / temperature
    teacher_scaled_prob = 1 / (1 + np.exp(-scaled_teacher_logits))

    # Gradients: d/dz of a binary cross-entropy with target t is (p - t),
    # for the hard label and for the soft teacher target alike. The teacher
    # term therefore vanishes exactly when the student matches the teacher.
    grad_ce = preds_prob - labels
    grad_kl = preds_prob - teacher_scaled_prob
    grad = alpha * grad_ce + (1 - alpha) * temperature**2 * grad_kl

    # Hessians: d/dz of (p - t) is p * (1 - p) for any target t. It is never
    # negative, so the combined hessian stays non-negative.
    hess_ce = preds_prob * (1 - preds_prob)
    hess_kl = preds_prob * (1 - preds_prob)
    hess = alpha * hess_ce + (1 - alpha) * temperature**2 * hess_kl

    return grad, hess

# ------------------------------------
# Model Training and Evaluation
# ------------------------------------

def train_xgb_model(
    X_train, y_train,
    params=XGB_DEFAULT_PARAMS,
    early_stopping_rounds=EARLY_STOPPING_NORMAL,
    kd_enabled=False,
    soft_labels=None,
    verbose=False,
    num_boost_round=10000
):
    """
    Train an XGBoost model (with optional KD) on provided data.

    X_train, y_train: training features and hard labels.
    params: parameter dict handed to xgb.train.
    early_stopping_rounds: early-stopping rounds used when kd_enabled is False.
        When kd_enabled is True this argument is ignored and EARLY_STOPPING_KD
        is used instead.
    kd_enabled: if True, train with distillation_gradient_hessian as the
        custom objective.
    soft_labels: teacher probabilities, required if kd_enabled is True.
        Whenever provided they are stored in the DMatrix weight slot, also
        when kd_enabled is False.
    verbose: forwarded to xgb.train as verbose_eval.
    num_boost_round: upper bound on the number of boosting rounds.

    Returns the trained model produced by xgb.train.
    Raises ValueError if kd_enabled is True and soft_labels is None.
    """
    if kd_enabled and soft_labels is None:
        raise ValueError("soft_labels must be provided when kd_enabled is True")

    dtrain = xgb.DMatrix(X_train, label=y_train, weight=soft_labels)
    # NOTE(review): the only evaluation set is the training matrix itself, so
    # early stopping monitors the training metric rather than held-out data.
    evals = [(dtrain, 'train')]

    if kd_enabled:
        obj = distillation_gradient_hessian
        esr = EARLY_STOPPING_KD
    else:
        obj = None
        esr = early_stopping_rounds

    model = xgb.train(
        params,
        dtrain,
        num_boost_round=num_boost_round,
        obj=obj,
        evals=evals,
        verbose_eval=verbose,
        early_stopping_rounds=esr
    )

    return model


def cross_validate_xgb(
    features, labels,
    params=XGB_DEFAULT_PARAMS,
    n_splits=CV_FOLDS,
    kd_enabled=False,
    soft_labels=None
):
    """
    Perform K-fold CV for training and evaluation.

    features, labels: full dataset; both are converted to NumPy arrays.
    params: XGBoost parameter dict forwarded to train_xgb_model.
    n_splits: number of folds (contiguous, unshuffled KFold splits).
    kd_enabled: train each fold with the distillation objective.
    soft_labels: per-sample teacher probabilities aligned with `labels`;
        required when kd_enabled is True.

    Returns a dict with:
      'metrics'  - per-metric mean over folds
      'roc_data' - list of (fpr, tpr, roc_auc), one entry per fold
      'mean_roc' - (mean_fpr, mean_tpr) on a 100-point FPR grid
      'ci'       - {'roc_auc': (low, high), 'pr_auc': (low, high)}, each bound
                   being the mean of the per-fold bootstrap bounds
      'model'    - the model trained on the LAST fold only
    Raises ValueError if kd_enabled is True without soft_labels, or if
    soft_labels does not have one entry per label.
    """
    X = np.asarray(features)
    y = np.asarray(labels)

    if kd_enabled and soft_labels is None:
        raise ValueError("soft_labels must be provided when kd_enabled is True")

    # Coerce to ndarray so the positional fold indices below select rows by
    # position; a pandas Series or a list would be indexed differently or fail.
    soft = None if soft_labels is None else np.asarray(soft_labels)
    if soft is not None and len(soft) != len(y):
        raise ValueError(
            f"soft_labels has {len(soft)} entries but labels has {len(y)}"
        )

    kf = KFold(n_splits=n_splits, shuffle=False)

    mean_fpr = np.linspace(0, 1, 100)
    tprs = []
    roc_data = []
    all_metrics = []
    roc_cis = []
    pr_cis = []

    for train_idx, test_idx in kf.split(X):
        X_tr, X_te = X[train_idx], X[test_idx]
        y_tr, y_te = y[train_idx], y[test_idx]
        soft_tr = soft[train_idx] if soft is not None else None

        model = train_xgb_model(
            X_tr, y_tr,
            params=params,
            kd_enabled=kd_enabled,
            soft_labels=soft_tr
        )

        dtest = xgb.DMatrix(X_te, label=y_te)
        y_prob = model.predict(dtest)

        metrics, (fpr, tpr) = compute_classification_metrics(y_te, y_prob)
        all_metrics.append(metrics)
        roc_data.append((fpr, tpr, metrics['roc_auc']))
        roc_cis.append(bootstrap_confidence_interval(y_te, y_prob, 'roc'))
        pr_cis.append(bootstrap_confidence_interval(y_te, y_prob, 'pr'))

        # Interpolate onto the shared FPR grid so fold curves can be averaged
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        tprs.append(interp_tpr)

    # Aggregate across folds
    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    mean_metrics = {k: np.mean([m[k] for m in all_metrics]) for k in all_metrics[0]}
    ci = {
        'roc_auc': (
            np.mean([low for low, high in roc_cis]),
            np.mean([high for low, high in roc_cis])
        ),
        'pr_auc': (
            np.mean([low for low, high in pr_cis]),
            np.mean([high for low, high in pr_cis])
        )
    }

    return {
        'metrics': mean_metrics,
        'roc_data': roc_data,
        'mean_roc': (mean_fpr, mean_tpr),
        'ci': ci,
        'model': model
    }

# ------------------------------------
# Inference
# ------------------------------------
def inference_xgb(test_df, test_labels, model, feature_col='patientunitstay'):
    """
    Score a held-out dataset with a trained model and summarise the results.

    test_df: DataFrame holding the model inputs plus the column named by
        `feature_col`, which is dropped before prediction.
    test_labels: ground-truth labels aligned with test_df.
    model: either an xgb.Booster (scored through a DMatrix) or any object
        exposing predict_proba, whose column 1 is taken as the positive score.
    feature_col: name of the column removed from test_df before prediction.

    Returns a dict shaped like the cross-validation result: 'metrics',
    'roc_data' (a single-entry list), 'mean_roc', 'ci' and 'model'.
    """
    inputs = test_df.drop(columns=[feature_col])
    y_true = np.asarray(test_labels)

    # Native boosters need a DMatrix; everything else goes through predict_proba
    if not isinstance(model, xgb.Booster):
        y_prob = model.predict_proba(inputs)[:, 1]
    else:
        y_prob = model.predict(xgb.DMatrix(inputs))

    metrics, (fpr, tpr) = compute_classification_metrics(y_true, y_prob)
    intervals = {
        'roc_auc': bootstrap_confidence_interval(y_true, y_prob, 'roc'),
        'pr_auc': bootstrap_confidence_interval(y_true, y_prob, 'pr'),
    }

    # Single "fold": resample the ROC curve on the common 100-point grid
    grid_fpr = np.linspace(0, 1, 100)
    grid_tpr = np.interp(grid_fpr, fpr, tpr)
    grid_tpr[0] = 0.0
    grid_tpr[-1] = 1.0

    return {
        'metrics': metrics,
        'roc_data': [(fpr, tpr, metrics['roc_auc'])],
        'mean_roc': (grid_fpr, grid_tpr),
        'ci': intervals,
        'model': model
    }

# ------------------------------------
# Plotting
# ------------------------------------
def plot_roc_curve(roc_data, mean_roc, ci):
    """
    Draw per-fold ROC curves together with the mean ROC curve.

    roc_data: iterable of (fpr, tpr, auc) tuples, one per fold.
    mean_roc: (mean_fpr, mean_tpr) pair for the averaged curve.
    ci: dict whose 'roc_auc' entry is a (lower, upper) pair shown in the legend.

    Displays the figure with plt.show(); returns None.
    """
    _, ax = plt.subplots(figsize=(6, 5))

    # Faint individual fold curves
    for fold_no, (fold_fpr, fold_tpr, fold_auc) in enumerate(roc_data, start=1):
        fold_label = f'Fold {fold_no} ROC (AUC={fold_auc:.2f})'
        ax.plot(fold_fpr, fold_tpr, alpha=0.3, lw=1, label=fold_label)

    # Bold averaged curve; its AUC is recomputed from the averaged points
    grid_fpr, grid_tpr = mean_roc
    averaged_auc = auc(grid_fpr, grid_tpr)
    ci_low, ci_high = ci['roc_auc']
    averaged_label = f'Mean ROC (AUC={averaged_auc:.2f} ± [{ci_low:.2f}, {ci_high:.2f}])'
    ax.plot(grid_fpr, grid_tpr, lw=2, label=averaged_label)

    # Chance diagonal for reference
    ax.plot([0, 1], [0, 1], linestyle='--', lw=2, label='Random Guess')

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.05)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve')
    ax.legend(loc='lower right')
    plt.show()


## Model Trained on MIMIC-III

In [None]:
# Load MIMIC-III data and labels
# NOTE(review): joblib.load unpickles arbitrary Python objects; only load
# files that come from a trusted source.
mimic_df = joblib.load('Variables/mimic_iii_data.pkl')
mimic_labels = joblib.load('Variables/mimic_iii_label.pkl')

# Full feature model: cross-validate on every MIMIC-III column, then plot the
# per-fold and mean ROC curves and print the fold-averaged metrics.
full_results = cross_validate_xgb(mimic_df, mimic_labels)
plot_roc_curve(full_results['roc_data'], full_results['mean_roc'], full_results['ci'])
print("Full Model Metrics:")
for metric, value in full_results['metrics'].items():
    print(f"{metric}: {value:.4f}")

## Model Trained on MIMIC-III with Limited Features Available in eICU

In [None]:
# Load eICU data and labels
# NOTE(review): joblib.load unpickles arbitrary Python objects; only load
# files that come from a trusted source.
eicu_df = joblib.load('Variables/eicu_data.pkl')
eicu_labels = joblib.load('Variables/eicu_label.pkl')
# Limited feature CV
# The feature set is every eICU column except 'patientunitstay', in eICU
# column order. Assumes each of those columns also exists in mimic_df - TODO confirm.
common_features = eicu_df.drop(columns=['patientunitstay']).columns.tolist()
mimic_limited_df = mimic_df[common_features]
limited_results = cross_validate_xgb(mimic_limited_df, mimic_labels)
# cross_validate_xgb returns the model trained on its last fold only.
small_model = limited_results['model']

# Inference on eICU
inf_results = inference_xgb(eicu_df, eicu_labels, small_model)
print("\nSmall Model Inference Metrics:")
for metric, value in inf_results['metrics'].items():
    print(f"{metric}: {value:.4f}")

## Knowledge Distillation

In [None]:
# Score every MIMIC-III row with the full-feature teacher to obtain the soft
# labels used for distillation.
# NOTE(review): full_results['model'] is the model from the last CV fold, so it
# was trained on part of the rows it scores here; teacher probabilities for
# those rows are in-sample predictions.
booster_df = xgb.DMatrix(mimic_df, label=mimic_labels)
teacher_scores = full_results['model'].predict(booster_df)
# Train the limited-feature student with the distillation objective; the
# teacher scores are passed as per-sample soft labels.
distilled_results = cross_validate_xgb(
    mimic_limited_df, mimic_labels,
    kd_enabled=True,
    soft_labels=teacher_scores
)
# As above, this is the student trained on the last fold only.
distilled_model = distilled_results['model']

## Model Calibration

In [None]:
# Round-trip the distilled booster through disk so it can be wrapped in the
# scikit-learn XGBClassifier interface that CalibratedClassifierCV expects.
distilled_model.save_model('Models/mimic_distilled.model')
clf = xgb.XGBClassifier()
clf.load_model('Models/mimic_distilled.model')
# cv='prefit': the wrapped classifier is used as-is and only the isotonic
# calibration map is fitted.
calibrated_clf = CalibratedClassifierCV(
    estimator=clf, method='isotonic', cv='prefit'
)
# NOTE(review): the calibrator is fitted on the full eICU set and then
# evaluated on that same set below, so the reported calibrated metrics are
# in-sample for the calibration step.
calibrated_clf.fit(
    eicu_df.drop(columns=['patientunitstay']),
    eicu_labels
)
cal_inf_results = inference_xgb(eicu_df, eicu_labels, calibrated_clf)
print("\nCalibrated Model Inference Metrics:")
for metric, value in cal_inf_results['metrics'].items():
    print(f"{metric}: {value:.4f}")