In [None]:
# Mount Google Drive so later cells can read the dataset and write results under /content/drive.
from google.colab import drive as colab_drive

colab_drive.mount('/content/drive')

In [None]:
!pip install shap
!pip install tabpfn

In [None]:
import numpy as np
import pandas as pd

import sklearn
import sklearn.metrics
from sklearn.metrics import (accuracy_score, auc, classification_report, confusion_matrix, f1_score, matthews_corrcoef, precision_recall_curve, precision_score, recall_score, roc_auc_score, roc_curve, brier_score_loss)
from sklearn.utils import resample
from math import sqrt
from scipy import stats as st
from random import randrange

from matplotlib import pyplot
import seaborn as sns

from sklearn.model_selection import StratifiedKFold
from tabpfn import TabPFNClassifier
from sklearn.calibration import CalibratedClassifierCV

import shap
from sklearn.inspection import PartialDependenceDisplay
from sklearn.calibration import calibration_curve

import cv2
from google.colab.patches import cv2_imshow
from PIL import Image

# Never truncate rows when displaying pandas objects.
pd.options.display.max_rows = None

# Preparing Data

In [None]:
# Load the study dataset from Drive; the first CSV column becomes the row index.
DATA_CSV_PATH = '/content/drive/MyDrive/DMVO-mRS/final_data.csv'
data = pd.read_csv(DATA_CSV_PATH, index_col=0)

In [None]:
print([*data.columns])

In [None]:
# Binary outcome from the 90-day mRS: <= 2 -> 0, > 2 -> 1.
# Rows with a missing mRS receive no label and are dropped.
mrs_90 = data['mRS at 90 days']
data['OUTCOME'] = (mrs_90 > 2).astype(float).where(mrs_90.notna())
data = data.dropna(subset=['OUTCOME'])

data['OUTCOME'].value_counts(normalize=False, dropna=False)

In [None]:
# Predictors (x) are every column except the raw mRS and the derived label,
# so the target cannot leak into the features; y is the binary label.
outcomes = ['mRS at 90 days', 'OUTCOME']

y = data['OUTCOME']
x = data.drop(columns=outcomes)

In [None]:
# Sanity check: y and x should have the same number of rows.
for frame in (y, x):
    print(frame.shape)

In [None]:
def bootstrap_ci(metric_list, n_bootstraps=1000, alpha=0.05, random_state=0):
    """Percentile-bootstrap confidence interval for the mean of ``metric_list``.

    Parameters
    ----------
    metric_list : sequence of float
        Values whose mean is bootstrapped (here: one metric value per CV fold).
    n_bootstraps : int
        Number of bootstrap resamples.
    alpha : float
        Two-sided significance level; 0.05 gives a 95% interval.
    random_state : int, numpy Generator or None
        Seed for the resampling. The fixed default makes reported intervals
        reproducible across runs; pass None for fresh randomness.

    Returns
    -------
    (lower_bound, upper_bound) : tuple of float
        ``(nan, nan)`` when ``metric_list`` is empty.
    """
    values = np.asarray(metric_list, dtype=float)
    if values.size == 0:
        return np.nan, np.nan

    rng = np.random.default_rng(random_state)
    # Each row of ``idx`` is one resample, drawn with replacement, of the same size as the input.
    idx = rng.integers(0, values.size, size=(n_bootstraps, values.size))
    bootstrapped_metrics = values[idx].mean(axis=1)

    lower_bound = np.percentile(bootstrapped_metrics, alpha / 2 * 100)
    upper_bound = np.percentile(bootstrapped_metrics, (1 - alpha / 2) * 100)
    return lower_bound, upper_bound

# TabPFN

No hyperparameter tuning is performed for TabPFN, because the paper that introduced it states that TabPFN needs no hyperparameter tuning:


*We present TabPFN, a trained Transformer that can do supervised classification for small tabular datasets in less than a second, needs no hyperparameter tuning and is competitive with state-of-the-art classification methods.*

https://doi.org/10.48550/arXiv.2207.01848


In [None]:
# Initialize the stratified 10-fold cross-validator used for both feature selection and evaluation.
# shuffle=True with a fixed random_state keeps the fold assignment reproducible across runs.
N_SPLITS = 10
cv = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=31)

In [None]:
from sklearn.linear_model import Lasso
from collections import Counter
from sklearn.preprocessing import MinMaxScaler

In [None]:
def select_features_using_lasso(X_train, y_train, alpha=0.01):
    """Return the columns of ``X_train`` that an L1-penalised linear model keeps.

    A ``Lasso(alpha=alpha)`` is fit on ``(X_train, y_train)``; every column whose
    fitted coefficient is exactly zero is discarded. ``X_train`` must expose
    ``.columns`` (a DataFrame); the retained names come back in column order.
    """
    coefficients = Lasso(alpha=alpha).fit(X_train, y_train).coef_
    is_kept = coefficients != 0
    return X_train.columns[is_kept]

# Count how many training folds Lasso keeps each feature in, then keep the majority-vote set.
# NOTE(review): the training folds together cover every row, so the final feature set is
# chosen using rows that later serve as CV test folds in the evaluation loop. This can make
# the cross-validated metrics optimistic; nesting the selection inside each evaluation fold
# would avoid it.
feature_counter = Counter()

for train_index, _ in cv.split(x, y):
    X_train_fold = x.iloc[train_index]
    y_train_fold = y.iloc[train_index]

    # Min-max scale with this training fold's statistics only, so the L1 penalty
    # treats features on a common scale.
    scaler = MinMaxScaler()
    X_train_fold_scaled = pd.DataFrame(
        scaler.fit_transform(X_train_fold),
        columns=X_train_fold.columns,
        index=X_train_fold.index,
    )

    selected_features = select_features_using_lasso(X_train_fold_scaled, y_train_fold)
    feature_counter.update(selected_features)

# Keep features selected in a strict majority of the folds.
majority_threshold = cv.get_n_splits() / 2
selected_features_final = [feat for feat, count in feature_counter.items() if count > majority_threshold]

# Save the selected features, one per line.
with open("/content/drive/MyDrive/DMVO-mRS/selected_features.txt", "w") as file:
    file.writelines(feature + "\n" for feature in selected_features_final)

print(selected_features_final)

In [None]:
# Per-fold metric containers (one value appended per CV fold), pooled out-of-fold
# labels/probabilities for the calibration curve, and interpolated per-fold curves.
(precision_list, recall_list, f1_list, acc_list, mcc_list, auroc_list,
 auprc_list, tpr_list, prc_list, brier_list, true_probs_list,
 pred_probs_list) = ([] for _ in range(12))

# Running sum of the per-fold 2x2 confusion matrices.
aggregate_cm = np.zeros(shape=(2, 2))
# Shared 101-point grid on [0, 1] onto which each fold's ROC/PR curve is interpolated.
base_fpr = np.linspace(start=0, stop=1, num=101)

In [None]:
# Cross-validated evaluation of TabPFN on the Lasso-selected features (no hyperparameter
# search is performed). Each test fold is scored by a model fitted without it; per-fold
# metrics and interpolated curves are appended to the containers created earlier.

X_selected = x[selected_features_final]

for train_index, test_index in cv.split(x, y):
    X_train_fold, X_test_fold = X_selected.iloc[train_index], X_selected.iloc[test_index]
    y_train_fold, y_test_fold = y.iloc[train_index], y.iloc[test_index]

    # The sigmoid calibrator is learned from 5-fold out-of-fold TabPFN predictions inside the
    # training fold. Calibrating (cv='prefit') on the very rows TabPFN was fit on would fit the
    # sigmoid to in-sample probabilities. With ensemble=False, one TabPFN is then refit on the
    # whole training fold and paired with that calibrator.
    tabpfn = TabPFNClassifier(device='cuda', N_ensemble_configurations=8)
    calibrated = CalibratedClassifierCV(tabpfn, method='sigmoid', cv=5, ensemble=False)
    calibrated.fit(X_train_fold, y_train_fold)

    preds = calibrated.predict(X_test_fold.values)
    probs = calibrated.predict_proba(X_test_fold.values)[:, 1]

    # Pool true labels and positive-class probabilities across folds for the calibration curve.
    true_probs_list.extend(y_test_fold)
    pred_probs_list.extend(probs)

    # Threshold metrics from hard predictions; ranking/probability metrics from probs.
    precision_list.append(precision_score(y_test_fold, preds))
    recall_list.append(recall_score(y_test_fold, preds))
    f1_list.append(f1_score(y_test_fold, preds))
    acc_list.append(accuracy_score(y_test_fold, preds))
    mcc_list.append(matthews_corrcoef(y_test_fold, preds))
    auroc_list.append(roc_auc_score(y_test_fold, probs))
    brier_list.append(brier_score_loss(y_test_fold, probs))

    # Precision-recall curve, computed once and used for AUPRC and the averaged curve.
    precision, recall, _ = precision_recall_curve(y_test_fold, probs)
    auprc_list.append(auc(recall, precision))
    # recall comes back in decreasing order; reverse so np.interp sees increasing x values.
    prc_list.append(np.interp(base_fpr, recall[::-1], precision[::-1]))

    # ROC curve interpolated onto the shared FPR grid and anchored at (0, 0).
    fpr, tpr, _ = roc_curve(y_test_fold, probs)
    tpr = np.interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tpr_list.append(tpr)

    # labels=[0, 1] keeps the matrix 2x2 even if a fold's labels or predictions contain a
    # single class, so it can always be added to the 2x2 aggregate.
    aggregate_cm += confusion_matrix(y_test_fold, preds, labels=[0, 1])

In [None]:
# Average each per-fold metric across the CV folds.
(precision_mean, recall_mean, f1_mean, acc_mean,
 mcc_mean, auroc_mean, auprc_mean, brier_mean) = map(
    np.mean,
    (precision_list, recall_list, f1_list, acc_list,
     mcc_list, auroc_list, auprc_list, brier_list),
)

In [None]:
# Summarise each metric as "mean (lower, upper)", where the bounds come from bootstrap_ci
# over the per-fold values, print the summaries and save them to results.txt.

metrics = {
    'Precision': precision_list,
    'Recall': recall_list,
    'F1 Score': f1_list,
    'Accuracy': acc_list,
    'MCC': mcc_list,
    'AUROC': auroc_list,
    'AUPRC': auprc_list,
    'Brier Score': brier_list
}

result_strings = {}

for metric_name, metric_list in metrics.items():
    lower_ci, upper_ci = bootstrap_ci(metric_list)
    # Format the mean with 3 fixed decimals like the CI bounds; round() + str() would print
    # e.g. "0.8" next to "(0.750, 0.850)".
    result_str = f"{metric_name}: {np.mean(metric_list):.3f} ({lower_ci:.3f}, {upper_ci:.3f})"
    result_strings[metric_name] = result_str
    print(result_str)

# Individual summaries are reused as legend labels in the plotting cells.
precision_str = result_strings['Precision']
recall_str = result_strings['Recall']
f1_str = result_strings['F1 Score']
acc_str = result_strings['Accuracy']
mcc_str = result_strings['MCC']
auroc_str = result_strings['AUROC']
auprc_str = result_strings['AUPRC']
brier_str = result_strings['Brier Score']

# One line per metric, in the same order as the metrics dict.
with open('/content/drive/MyDrive/DMVO-mRS/results.txt', 'w') as file:
    file.writelines(result_strings[name] + '\n' for name in metrics)

In [None]:
# Stack the per-fold interpolated curves and summarise them (mean and std per grid point),
# then build a 7-bin quantile calibration curve from the pooled out-of-fold predictions.
tpr_list = np.array(tpr_list)
prc_list = np.array(prc_list)

mean_tprs, std_tprs = np.mean(tpr_list, axis=0), np.std(tpr_list, axis=0)
mean_prcs, std_prcs = np.mean(prc_list, axis=0), np.std(prc_list, axis=0)

fraction_of_positives, mean_predicted_value = calibration_curve(
    true_probs_list,
    pred_probs_list,
    n_bins=7,
    strategy='quantile',
)

# Plotting

In [None]:
# Plot ROC curve: fold-averaged TPR on the shared FPR grid (legend = AUROC summary)
# against the chance diagonal; saved as panel A.
f = pyplot.figure(figsize=(10, 10))
ax = f.add_subplot()

ax.plot(base_fpr, mean_tprs, label=auroc_str, color='r', linewidth=3.5, alpha=0.75)
ax.plot([0, 1], [0, 1], linestyle='--', linewidth=2)

ax.set_title('A', x=-0.075, y=1.005, fontsize=75, pad=20)
ax.set_xlabel('False Positive Rate', fontsize=22, fontweight='heavy', labelpad=16)
ax.set_ylabel('True Positive Rate', fontsize=22, fontweight='heavy', labelpad=16)
for axis_name in ('y', 'x'):
    ax.tick_params(axis=axis_name, direction="out", labelsize=16)

leg = ax.legend(loc='lower right', fontsize=18)

f.savefig('/content/drive/MyDrive/DMVO-mRS/roc.jpg', dpi=300)
pyplot.show()

In [None]:
# Plot PR curve: fold-averaged precision on the shared [0, 1] grid, which serves as the
# recall axis here (legend = AUPRC summary); saved as panel B.
f = pyplot.figure(figsize=(10, 10))
ax = f.add_subplot()

ax.plot(base_fpr, mean_prcs, label=auprc_str, color='b', linewidth=3.5, alpha=0.75)

ax.set_title('B', x=-0.075, y=1.005, fontsize=75, pad=20)
ax.set_xlabel('Recall', fontsize=22, fontweight='heavy', labelpad=16)
ax.set_ylabel('Precision', fontsize=22, fontweight='heavy', labelpad=16)
for axis_name in ('y', 'x'):
    ax.tick_params(axis=axis_name, direction="out", labelsize=16)

leg = ax.legend(loc='lower left', fontsize=18)

f.savefig('/content/drive/MyDrive/DMVO-mRS/prc.jpg', dpi=300)
pyplot.show()

In [None]:
# Plot calibration curve: observed fraction of positives per probability bin against
# the mean predicted probability (legend = Brier summary), with the ideal diagonal;
# saved as panel C.
f = pyplot.figure(figsize=(10, 10))
ax = f.add_subplot()

ax.plot(mean_predicted_value, fraction_of_positives, 's-', label=brier_str, color='g', linewidth=3.5, alpha=0.75)
ax.plot([0, 1], [0, 1], linestyle='--', linewidth=2)

ax.set_title('C', x=-0.075, y=1.005, fontsize=75, pad=20)
ax.set_xlabel('Mean Predicted Probability', fontsize=22, fontweight='heavy', labelpad=16)
ax.set_ylabel('Fraction of Positives', fontsize=22, fontweight='heavy', labelpad=16)
for axis_name in ('y', 'x'):
    ax.tick_params(axis=axis_name, direction="out", labelsize=16)

leg = ax.legend(loc='lower right', fontsize=18)

f.savefig('/content/drive/MyDrive/DMVO-mRS/calibration.jpg', dpi=300)
pyplot.show()

In [None]:
# Plot the confusion matrix summed over all CV folds as an annotated heatmap; saved as panel D.
f = pyplot.figure(figsize=(10, 10))
ax = f.add_subplot()

aggregate_cm = aggregate_cm.astype(int)
sns.heatmap(
    aggregate_cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    cbar=False,
    annot_kws={"size": 16},
    linewidths=1,
    linecolor='black',
    ax=ax,
)

# Heatmap cells are centred at 0.5 and 1.5; row/column 0 is OUTCOME 0, row/column 1 is OUTCOME 1.
labels = ['mRS 0-2', 'mRS > 2']
tick_positions = [0.5, 1.5]
ax.set_xticks(tick_positions)
ax.set_xticklabels(labels, fontsize=16, fontweight='heavy')
ax.set_yticks(tick_positions)
ax.set_yticklabels(labels, fontsize=16, fontweight='heavy', va='center')

ax.set_xlabel('Predicted', fontsize=22, fontweight='heavy', labelpad=16)
ax.set_ylabel('Truth', fontsize=22, fontweight='heavy', labelpad=16)

for axis_name in ('y', 'x'):
    ax.tick_params(axis=axis_name, direction="out", pad=10)
ax.set_title('D', x=-0.075, y=1.005, fontsize=75, pad=20)

f.savefig('/content/drive/MyDrive/DMVO-mRS/aggregate_cm.jpg', dpi=300)
pyplot.show()

In [None]:
# Stitch the four saved panels into one 2x2 figure:
# ROC | PR on the top row, calibration | confusion matrix on the bottom row.
roc, prc, cal, cm = (
    cv2.imread(panel_path)
    for panel_path in (
        '/content/drive/MyDrive/DMVO-mRS/roc.jpg',
        '/content/drive/MyDrive/DMVO-mRS/prc.jpg',
        '/content/drive/MyDrive/DMVO-mRS/calibration.jpg',
        '/content/drive/MyDrive/DMVO-mRS/aggregate_cm.jpg',
    )
)

top_row = cv2.hconcat([roc, prc])
bottom_row = cv2.hconcat([cal, cm])
fig = cv2.vconcat([top_row, bottom_row])

cv2_imshow(fig)

cv2.imwrite('/content/drive/MyDrive/DMVO-mRS/fig.jpg', fig, [cv2.IMWRITE_JPEG_QUALITY, 100])

# SHAP

In [None]:
import textwrap


def wrap_labels(ax, width, break_long_words=False):
    """Re-wrap the y tick labels of ``ax`` to at most ``width`` characters per line.

    Each label's text is passed through ``textwrap.fill``; ``break_long_words``
    controls whether single words longer than ``width`` are split. The wrapped
    labels are written back horizontally (rotation=0).
    """
    wrapped = [
        textwrap.fill(tick.get_text(), width=width, break_long_words=break_long_words)
        for tick in ax.get_yticklabels()
    ]
    ax.set_yticklabels(wrapped, rotation=0)

In [None]:
# Calculate SHAP values for the calibrated probability of the positive class (OUTCOME == 1).
# Explaining predict_proba rather than predict gives attributions on a continuous probability
# scale (the same output the partial dependence plots use), instead of a 0/1 step function.
# NOTE(review): `calibrated` is the model fitted in the last CV fold, so these explanations
# describe that fold's model, evaluated on all rows, including its own training rows.
def predict_positive_proba(features):
    """Return the calibrated positive-class probability for each row of ``features``."""
    return calibrated.predict_proba(features)[:, 1]

explainer = shap.Explainer(predict_positive_proba, x[selected_features_final])
shap_values = explainer(x[selected_features_final])

In [None]:
# Plot mean |SHAP value| per feature (up to 25 features) as a bar chart and save it.
shap.plots.bar(shap_values, max_display=25, show=False)

fig = pyplot.gcf()
ax = pyplot.gca()
fig.set_figheight(10)
fig.set_figwidth(6)

ax.set_xlabel("Mean |SHAP Value|", fontsize=12, fontweight='heavy', labelpad=8)
ax.tick_params(axis="both", direction="out", labelsize=12)

# If long feature names overlap, wrap_labels(ax, 30) can be called here before saving.
fig.savefig('/content/drive/MyDrive/DMVO-mRS/shap.jpg', dpi=300, bbox_inches='tight')

# Partial Dependence Plots

In [None]:
# Keep only rows whose value in each listed column lies strictly between (min, max)
# before drawing the partial dependence plots. Rows with a missing value in a listed
# column are dropped too. Note that this rebinds the notebook-level `x` for all later cells.
column_filters = {
    'Age': (18, 90),
}

for column, (lower, upper) in column_filters.items():
    if column not in x.columns:
        continue  # filters on non-predictor columns are skipped
    x = x[x[column].between(lower, upper, inclusive='neither')]

In [None]:
# Global matplotlib defaults for figures created from here on (the partial dependence plots):
# large 300-dpi canvas, bold axis labels, thicker lines, small tick labels.
pyplot.rcParams.update({
    "figure.figsize": (15, 15),
    "figure.dpi": 300,
    "axes.labelweight": 'bold',
    "axes.labelsize": 12,
    "axes.labelpad": 6,
    "font.weight": 'normal',
    "lines.linewidth": 2.5,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
})

In [None]:
# Model input columns, in the same order used for training and for the SHAP values.
feature_names = x[selected_features_final].columns

# Predictors to be treated as categorical (discrete levels) in the partial dependence plots.
categorical_features = [
    'Sex', 'Race', 'Initial Hospital', 'Antiplatelet Use', 'Diuretic Use',
    'Current or Former Smoker', 'Current Alcohol Use', 'Hypertension', 'Dyslipidemia',
    'Diabetes', 'Heart Disease', 'Atrial Fibrillation', 'History of Malignancy',
    'Prior Stroke or TIA', 'HIV', 'HCV', 'Chronic Kidney Disease', 'Sleep Apnea', 'PVD',
    'DVT or PE', 'Obesity', 'Admission LAMS', 'Premorbid mRS', 'Stroke Etiology',
    'Occlusion Laterality', 'Occlusion Site', 'mTICI', 'DSA Collaterals',
    'Single Phase CTA Collateral Score', 'Dynamic CTP mCTA Collateral Score', 'COVES Score',
    'Clot Burden Score', 'Baseline NCCT ASPECTS', 'Hyperdense MCA',
    'Hemorrhagic Transformation', 'Type of Thrombectomy', 'Number of Passes',
    'Type of Anesthesia', 'IV TPA', 'Mechanical Thrombectomy',
]

In [None]:
# Rank features by mean |SHAP value| and draw partial dependence plots for the top ones.
N_PDP_FEATURES = 9

# Mean absolute SHAP value per feature (column order matches feature_names).
mean_abs_shap_values = np.mean(np.abs(shap_values.values), axis=0)

shap_summary = pd.DataFrame({'Feature': feature_names, 'Mean SHAP': mean_abs_shap_values})
shap_summary_sorted = shap_summary.sort_values('Mean SHAP', ascending=False)

features = shap_summary_sorted['Feature'].head(N_PDP_FEATURES).tolist()

# Use a separate name rather than overwriting the full categorical_features list. Overwriting
# it made a re-run (e.g. after recomputing SHAP) depend on the previous top-feature subset.
pdp_categorical_features = [item for item in features if item in categorical_features]

PartialDependenceDisplay.from_estimator(
    calibrated,
    x[selected_features_final],
    features,
    categorical_features=pdp_categorical_features,
)
pyplot.savefig('/content/drive/MyDrive/DMVO-mRS/pdp.png', dpi=300, bbox_inches='tight')
pyplot.show()