In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os

from tabpfn import TabPFNClassifier
import joblib
import shap

from sklearn.metrics import (accuracy_score, auc, classification_report, confusion_matrix, f1_score, matthews_corrcoef, precision_recall_curve, precision_score, recall_score, roc_auc_score, roc_curve, brier_score_loss)
from sklearn.utils import resample
from sklearn.model_selection import StratifiedKFold
from sklearn.calibration import CalibratedClassifierCV
from sklearn.calibration import calibration_curve

## Prepare Data

In [None]:
#Load the Bernese registry data.
df = pd.read_csv('mevo_bern.csv')

##Split df +/- EVT, split y from x, drop columns from x.
#Columns excluded from the predictor matrix (outcomes, the treatment flag and the target itself).
col_to_drop = ['NIHSS 24h', '3M mRS', 'IAT_simplified', 'nihss_24h_target']


def _cohort(evt_flag):
    """Return (all variables, target, predictors) for rows with IAT_simplified == evt_flag."""
    cohort_all = df.loc[df['IAT_simplified'] == evt_flag].reset_index(drop=True)
    cohort_target = cohort_all['nihss_24h_target']
    cohort_predictors = cohort_all.drop(columns=col_to_drop)
    return cohort_all, cohort_target, cohort_predictors


#No EVT (IAT_simplified == 0); *_pred holds only the TabPFN predictors.
x_no_evt_all, y_no_evt, x_no_evt_pred = _cohort(0)

#EVT (IAT_simplified == 1)
x_evt_all, y_evt, x_evt_pred = _cohort(1)

## Validation of the TabPFN model in the Bernese local stroke registry cohort: 5-fold cross-validation (patients without EVT)

In [None]:
#5-fold stratified cross-validation.
#shuffle=False: folds follow row order within each class, so the split is deterministic without a seed.
cv = StratifiedKFold(
    n_splits=5,
    shuffle=False,
)

In [None]:
#Per-fold metric accumulators (one value appended per CV fold).
(precision_list, recall_list, f1_list, acc_list,
 mcc_list, auroc_list, auprc_list, brier_list) = ([] for _ in range(8))

#Per-fold ROC / PR curves, each interpolated onto base_fpr.
tpr_list, prc_list = [], []

#Pooled out-of-fold labels and calibrated probabilities (for the calibration curve).
true_probs_list, pred_probs_list = [], []

#Confusion matrix summed over folds, and the shared 101-point grid on [0, 1] for curve interpolation.
aggregate_cm = np.zeros(shape=(2, 2))
base_fpr = np.linspace(start=0, stop=1, num=101)

#Per-fold positional test indices and hard predictions (to map predictions back to patients).
indexes, preds_list = [], []

In [None]:
#Perform 5-fold cross-validation.
#NOTE(review): CalibratedClassifierCV(cv='prefit') is fitted below on the SAME rows TabPFN was
#fitted on. sklearn expects disjoint data for 'prefit' calibration, so the sigmoid may be learned on
#optimistic in-sample scores. Consider calibrating on a held-out inner split (or cv=k) — TODO confirm
#the intended design.
for train_index, test_index in cv.split(x_no_evt_pred, y_no_evt):

    X_train_fold, X_test_fold = x_no_evt_pred.iloc[train_index], x_no_evt_pred.iloc[test_index]
    y_train_fold, y_test_fold = y_no_evt.iloc[train_index], y_no_evt.iloc[test_index]

    #Positional test indices, used afterwards to map out-of-fold predictions back onto x_no_evt_all.
    indexes.append(test_index)

    tabpfn = TabPFNClassifier(device='cuda')
    tabpfn.fit(X_train_fold, y_train_fold)

    calibrated = CalibratedClassifierCV(tabpfn, method='sigmoid', cv='prefit')
    calibrated.fit(X_train_fold, y_train_fold)

    #Predict on the DataFrame, exactly as in fit, so feature names stay consistent
    #(predicting on .values triggers a feature-name warning that the global filter hides).
    probs = calibrated.predict_proba(X_test_fold)[:, 1]
    preds = calibrated.predict(X_test_fold)
    preds_list.append(preds)

    #Pool out-of-fold labels and probabilities for the calibration curve.
    true_probs_list.extend(y_test_fold)
    pred_probs_list.extend(probs)

    #Precision-recall curve, computed once and reused for AUPRC and the interpolated PR curve.
    precision, recall, _ = precision_recall_curve(y_test_fold, probs)

    #Calculate performance metrics
    precision_list.append(precision_score(y_test_fold, preds))
    recall_list.append(recall_score(y_test_fold, preds))
    f1_list.append(f1_score(y_test_fold, preds))
    acc_list.append(accuracy_score(y_test_fold, preds))
    mcc_list.append(matthews_corrcoef(y_test_fold, preds))
    auroc_list.append(roc_auc_score(y_test_fold, probs))
    #NOTE(review): trapezoidal auc(recall, precision) can be optimistic; sklearn recommends
    #average_precision_score for summarising a PR curve. Kept unchanged for comparability.
    auprc_list.append(auc(recall, precision))
    brier_list.append(brier_score_loss(y_test_fold, probs))

    #Add this fold's confusion matrix to the running total.
    aggregate_cm += confusion_matrix(y_test_fold, preds)

    #ROC curve interpolated onto the common grid, anchored at TPR = 0 for FPR = 0.
    fpr, tpr, _ = roc_curve(y_test_fold, probs)
    tpr = np.interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tpr_list.append(tpr)

    #precision_recall_curve returns recall in decreasing order; reverse both arrays so np.interp
    #gets an increasing x-grid (base_fpr is reused here as the recall grid).
    prc = np.interp(base_fpr, recall[::-1], precision[::-1])
    prc_list.append(prc)

#Confusion matrix summed over all folds.
print(aggregate_cm)

In [None]:
#Flatten the per-fold test indices and predictions into one entry per patient.
indexes_flat = [num for sublist in indexes for num in sublist]
preds_list_flat = [num for sublist in preds_list for num in sublist]

#Order by original row position ('index' = positional row number in x_no_evt_all).
preds_no_evt = pd.DataFrame({'index': indexes_flat, 'predictions': preds_list_flat})
preds_no_evt = preds_no_evt.sort_values(by='index')

#Every patient must receive exactly one out-of-fold prediction.
assert preds_no_evt['index'].is_unique and len(preds_no_evt) == len(x_no_evt_all)

#Add predictions to x_no_evt_all.
#Column assignment aligns on index LABELS, not row order. After sort_values the frame still carries
#its 0..n-1 labels in fold order, so assigning it directly ignores the sort and attaches predictions
#to the wrong patients. x_no_evt_all has a fresh RangeIndex (reset_index(drop=True)), so
#re-indexing by the positional 'index' column lines each prediction up with its own row.
x_no_evt_all['predicted'] = preds_no_evt.set_index('index')['predictions']

In [None]:
#Average every per-fold metric across the CV folds.
(precision_mean, recall_mean, f1_mean, acc_mean,
 mcc_mean, auroc_mean, auprc_mean, brier_mean) = [
    np.mean(fold_values)
    for fold_values in (precision_list, recall_list, f1_list, acc_list,
                        mcc_list, auroc_list, auprc_list, brier_list)
]

In [None]:
#Calculate the confidence intervals for each metric.
def bootstrap_ci(metric_list, n_bootstraps=1000, alpha=0.05, random_state=42):
    """Percentile bootstrap confidence interval for the mean of ``metric_list``.

    Draws ``n_bootstraps`` resamples of ``metric_list`` with replacement (each the same size as
    the input), takes the mean of each resample, and returns the ``alpha/2`` and ``1 - alpha/2``
    percentiles of those means.

    Parameters
    ----------
    metric_list : sequence of float
        Values to bootstrap (in this notebook: one metric value per CV fold).
    n_bootstraps : int
        Number of bootstrap resamples.
    alpha : float
        Two-sided significance level; 0.05 gives a 95% interval.
    random_state : None, int or numpy.random.Generator
        Seed passed to ``numpy.random.default_rng``. It is fixed by default so reported CIs are
        reproducible across runs; pass ``None`` for a fresh random draw.

    Returns
    -------
    (lower_bound, upper_bound) : tuple of float
        Both bounds are NaN when ``metric_list`` is empty.
    """
    #NOTE(review): with only a handful of fold values (5 here) the bootstrap CI is coarse.
    values = np.asarray(metric_list, dtype=float)
    if values.size == 0:
        return float('nan'), float('nan')

    rng = np.random.default_rng(random_state)
    #Each row is one bootstrap resample drawn with replacement; average across each row.
    resamples = rng.choice(values, size=(n_bootstraps, values.size), replace=True)
    bootstrapped_metrics = resamples.mean(axis=1)

    lower_bound = np.percentile(bootstrapped_metrics, alpha / 2 * 100)
    upper_bound = np.percentile(bootstrapped_metrics, (1 - alpha / 2) * 100)
    return lower_bound, upper_bound


#Per-fold metric values, each summarised below as "mean (lower CI, upper CI)".
metrics = {
    'Precision': precision_list,
    'Recall': recall_list,
    'F1 Score': f1_list,
    'Accuracy': acc_list,
    'MCC': mcc_list,
    'AUROC': auroc_list,
    'AUPRC': auprc_list,
    'Brier Score': brier_list
}

result_strings = {}

for metric_name, metric_list in metrics.items():
    #Format the unrounded mean directly. Rounding to 3 dp first and then formatting to 2 dp
    #double-rounds (e.g. 0.1251 -> 0.125 -> "0.12" instead of "0.13").
    mean = np.mean(metric_list)
    lower_ci, upper_ci = bootstrap_ci(metric_list)
    result_str = f"{metric_name}: {mean:.2f} ({lower_ci:.2f}, {upper_ci:.2f})"
    result_strings[metric_name] = result_str
    print(result_str)

#Individual summary strings.
precision_str = result_strings['Precision']
recall_str = result_strings['Recall']
f1_str = result_strings['F1 Score']
acc_str = result_strings['Accuracy']
mcc_str = result_strings['MCC']
auroc_str = result_strings['AUROC']
auprc_str = result_strings['AUPRC']
brier_str = result_strings['Brier Score']

In [None]:
#Calculate ROC, PR and calibration curves.
#Stack the per-fold curves (each interpolated onto base_fpr) and summarise them pointwise.
tpr_list = np.array(tpr_list)
mean_tprs = tpr_list.mean(axis=0)
std_tprs = tpr_list.std(axis=0)

prc_list = np.array(prc_list)
mean_prcs = prc_list.mean(axis=0)
std_prcs = prc_list.std(axis=0)

#Reliability curve over the pooled out-of-fold data: true_probs_list holds the observed labels
#(despite its name) and pred_probs_list the calibrated positive-class probabilities.
#strategy='quantile' puts roughly equal numbers of predictions into each of the 7 bins.
#Warnings are already silenced in the notebook's first cell, so no re-import/re-filter is needed here.
fraction_of_positives, mean_predicted_value = calibration_curve(true_probs_list, pred_probs_list, n_bins=7, strategy='quantile')

## Train TabPFN on all BMT-alone patients (no EVT) from the Bernese local stroke registry

In [None]:
#### Train and save TabPFN
#Fit TabPFN on every patient without EVT (no hold-out: this is the model exported below).
#NOTE(review): the cross-validation used device='cuda' while this final fit runs on CPU — confirm intended.
tabpfn = TabPFNClassifier(device='cpu')
tabpfn.fit(x_no_evt_pred, y_no_evt)

#Sigmoid calibration on top of the already-fitted TabPFN (fit returns the fitted calibrator).
#NOTE(review): with cv='prefit' the calibrator is fitted on the same rows TabPFN was fitted on;
#sklearn expects disjoint data for 'prefit' calibration — confirm this is intended.
calibrated = CalibratedClassifierCV(tabpfn, method='sigmoid', cv='prefit').fit(x_no_evt_pred, y_no_evt)

### Save the model for export, external validation in, and application to the DISTAL cohort

In [None]:
#Persist the calibrated model for use on the DISTAL cohort.
#joblib.dump does not create missing parent folders, so make sure models/ exists first.
#NOTE: joblib files are pickles — only load them from trusted sources.
os.makedirs('models', exist_ok=True)
joblib.dump(calibrated, 'models/TabPFN_trained_Bern.pkl')

#### SHAP Bar Plots

In [None]:
#Calculate SHAP values for the calibrated positive-class probability.
#Explaining calibrated.predict would attribute hard 0/1 labels (a step function of the inputs);
#predict_proba(...)[:, 1] gives attributions for the continuous risk score instead.
def predict_positive_proba(X):
    """Calibrated probability of classes_[1] (column 1 of predict_proba) for each row of X."""
    return calibrated.predict_proba(X)[:, 1]


explainer = shap.Explainer(predict_positive_proba, x_no_evt_pred)
shap_values = explainer(x_no_evt_pred)

In [None]:
#Plot the SHAP bar plot (up to 25 features) and save it to a relative plots/ folder.
#The previous target "/plots/..." pointed at the filesystem root, which usually does not exist
#or is not writable; create the relative folder if needed.
os.makedirs('plots', exist_ok=True)
shap.plots.bar(shap_values, max_display=25, show=False)
plt.savefig("plots/shap_bar.png", bbox_inches="tight", dpi=300)