In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, recall_score, matthews_corrcoef, roc_auc_score, confusion_matrix
from tensorflow.keras.layers import Input, Dense, Conv1D, MaxPooling1D, Flatten, Dropout, BatchNormalization
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping

# Load the dataset
dataset = pd.read_excel('Final_non_redundant_sequences.xlsx', na_filter=False)
X_data_name = 'whole_sample_dataset_esm2_t33_650M_UR50D_unified_1280_dimension.csv'
X_data = pd.read_csv(X_data_name, header=0, index_col=0, delimiter=',')
X = np.array(X_data)
y = np.array(dataset['label'])

# Split dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)

# Normalize the data
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

def build_model(input_shape):
    input = Input(input_shape)
    x = Conv1D(64, 5, strides=1, padding='same')(input)
    x = BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = MaxPooling1D(2, padding='same')(x)
    x = Dropout(0.25)(x)
    
    x = Conv1D(128, 5, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = MaxPooling1D(2, padding='same')(x)
    x = Dropout(0.25)(x)
    
    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(1, activation='sigmoid')(x)
    
    model = Model(inputs=input, outputs=x)
    return model

def step_decay(epoch):
    initial_lr = 0.01
    drop = 0.5
    epochs_drop = 10
    lr = initial_lr * np.power(drop, np.floor((1 + epoch) / epochs_drop))
    return lr

def train_model(X_train, y_train, X_test, y_test):
    input_shape = (1280, 1)
    model = build_model(input_shape)
    
    # Optimizer
    adam = tf.keras.optimizers.Adam(learning_rate=0.001)
    
    model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
    
    lrate = LearningRateScheduler(step_decay)
    early_stop = EarlyStopping(monitor='val_accuracy', patience=20, verbose=1, restore_best_weights=True)
    mc = ModelCheckpoint('best_model_1280.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
    callbacks_list = [lrate, early_stop, mc]
    
    class_weight = {0: 1, 1: 2}  # Adjust the weights as needed
    
    model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=100, callbacks=callbacks_list, batch_size=32, class_weight=class_weight)
    
    return model

# Train the model
trained_model = train_model(X_train, y_train, X_test, y_test)

# Load the best model
saved_model = load_model('best_model_1280.h5')

# Function to optimize threshold based on MCC
def optimize_threshold(y_true, y_pred_probas):
    thresholds = np.arange(0.1, 1.0, 0.05)
    best_mcc = -1
    best_threshold = 0.5
    
    for threshold in thresholds:
        y_pred = (y_pred_probas > threshold).astype(int)
        mcc = matthews_corrcoef(y_true, y_pred)
        
        if mcc > best_mcc:
            best_mcc = mcc
            best_threshold = threshold
    
    return best_threshold, best_mcc




Epoch 1/100


Epoch 1: val_accuracy improved from -inf to 0.25365, saving model to best_model_1280.h5


  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.25365
Epoch 3/100
Epoch 3: val_accuracy improved from 0.25365 to 0.25547, saving model to best_model_1280.h5
Epoch 4/100
Epoch 4: val_accuracy improved from 0.25547 to 0.71533, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.71533 to 0.90146, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy did not improve from 0.90146
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.90146
Epoch 8/100
Epoch 8: val_accuracy improved from 0.90146 to 0.90237, saving model to best_model_1280.h5
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90237
Epoch 10/100
Epoch 10: val_accuracy improved from 0.90237 to 0.91332, saving model to best_model_1280.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.91332
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91332
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.91332
Epoch 14/100
Epoch 14: val_accuracy 

In [2]:
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score, confusion_matrix
import math

k = 10
kf = KFold(n_splits=k, shuffle=True, random_state=1)

# Convert back to DataFrames if needed (assuming X_train is still a numpy array)
X_train = pd.DataFrame(X_train)
y_train = pd.DataFrame(y_train)

# Result collection lists
ACC_collection = []
BACC_collection = []
Sn_collection = []
Sp_collection = []
MCC_collection = []
AUC_collection = []

# Function to train the model for each fold
def ESM_CNN(X_train_CV, y_train_CV, X_valid_CV, y_valid_CV):
    # Train the model with the training fold
    model = train_model(X_train_CV, y_train_CV, X_valid_CV, y_valid_CV)
    return model

# Cross-validation loop
for train_index, test_index in kf.split(y_train):
    X_train_CV, X_valid_CV = X_train.iloc[train_index, :], X_train.iloc[test_index, :]
    y_train_CV, y_valid_CV = y_train.iloc[train_index], y_train.iloc[test_index]

    # Train the model for this fold
    model = ESM_CNN(X_train_CV, y_train_CV, X_valid_CV, y_valid_CV)
    
    # Load the best model
    saved_model = load_model('best_model_1280.h5')
    
    # Predict probabilities
    predicted_probabilities = saved_model.predict(X_valid_CV, batch_size=1)
    
    # Convert probabilities to class predictions
    predicted_class = (predicted_probabilities > 0.5).astype(int)
    
    # Compute confusion matrix
    y_true = y_valid_CV.values
    TN, FP, FN, TP = confusion_matrix(y_true, predicted_class).ravel()
    
    # Compute metrics
    ACC = (TP + TN) / (TP + TN + FP + FN)
    Sn = TP / (TP + FN)
    Sp = TN / (TN + FP)
    MCC = (TP * TN - FP * FN) / math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
    BACC = 0.5 * (Sn + Sp)
    AUC = roc_auc_score(y_true, predicted_probabilities)
    
    # Append metrics to collection lists
    ACC_collection.append(ACC)
    Sn_collection.append(Sn)
    Sp_collection.append(Sp)
    MCC_collection.append(MCC)
    BACC_collection.append(BACC)
    AUC_collection.append(AUC)

# Display the results for each fold
print("Average Accuracy:", np.mean(ACC_collection))
print("Average Balanced Accuracy:", np.mean(BACC_collection))
print("Average Sensitivity (Sn):", np.mean(Sn_collection))
print("Average Specificity (Sp):", np.mean(Sp_collection))
print("Average MCC:", np.mean(MCC_collection))
print("Average AUC:", np.mean(AUC_collection))

Epoch 1/100
Epoch 1: val_accuracy improved from -inf to 0.24601, saving model to best_model_1280.h5


  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.24601
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.24601
Epoch 4/100
Epoch 4: val_accuracy improved from 0.24601 to 0.26424, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.26424 to 0.86560, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.86560 to 0.90888, saving model to best_model_1280.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.90888
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90888
Epoch 9/100
Epoch 9: val_accuracy improved from 0.90888 to 0.91344, saving model to best_model_1280.h5
Epoch 10/100
Epoch 10: val_accuracy improved from 0.91344 to 0.92483, saving model to best_model_1280.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.92483
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.92483
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.92483
Epoch 14/100
Epoch 14: val_accuracy 

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26424
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26424
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26424 to 0.62187, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.62187 to 0.90661, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy did not improve from 0.90661
Epoch 7/100
Epoch 7: val_accuracy improved from 0.90661 to 0.91800, saving model to best_model_1280.h5
Epoch 8/100
Epoch 8: val_accuracy improved from 0.91800 to 0.92027, saving model to best_model_1280.h5
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.92027
Epoch 10/100
Epoch 10: val_accuracy improved from 0.92027 to 0.92483, saving model to best_model_1280.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.92483
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.92483
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.92483
Epoch 14/100
Epoch 14: val_accuracy 

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26424
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26424
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26424 to 0.51025, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.51025 to 0.91344, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.91344 to 0.91572, saving model to best_model_1280.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.91572
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.91572
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.91572
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.91572
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.91572
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91572
Epoch 13/100
Epoch 13: val_accuracy improved from 0.91572 to 0.92711, saving model to best_model_1280.h5
Epoch 14/100
Epoch 14: val_accuracy improved from 0.92711 to 0.92938, saving

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.24429
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.24429
Epoch 4/100
Epoch 4: val_accuracy improved from 0.24429 to 0.26712, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.26712 to 0.85845, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.85845 to 0.89726, saving model to best_model_1280.h5
Epoch 7/100
Epoch 7: val_accuracy improved from 0.89726 to 0.90639, saving model to best_model_1280.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90639
Epoch 9/100
Epoch 9: val_accuracy improved from 0.90639 to 0.91324, saving model to best_model_1280.h5
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.91324
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.91324
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91324
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.91324
Epoch 14/100
Epoch 14: val_accuracy 

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.24886
Epoch 3/100
Epoch 3: val_accuracy improved from 0.24886 to 0.64840, saving model to best_model_1280.h5
Epoch 4/100
Epoch 4: val_accuracy improved from 0.64840 to 0.80137, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy did not improve from 0.80137
Epoch 6/100
Epoch 6: val_accuracy improved from 0.80137 to 0.87443, saving model to best_model_1280.h5
Epoch 7/100
Epoch 7: val_accuracy improved from 0.87443 to 0.88128, saving model to best_model_1280.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.88128
Epoch 9/100
Epoch 9: val_accuracy improved from 0.88128 to 0.90639, saving model to best_model_1280.h5
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.90639
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.90639
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.90639
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.90639
Epoch 14/100
Epoch 14: val_accuracy 

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26256
Epoch 3/100
Epoch 3: val_accuracy improved from 0.26256 to 0.61872, saving model to best_model_1280.h5
Epoch 4/100
Epoch 4: val_accuracy did not improve from 0.61872
Epoch 5/100
Epoch 5: val_accuracy improved from 0.61872 to 0.89498, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy did not improve from 0.89498
Epoch 7/100
Epoch 7: val_accuracy improved from 0.89498 to 0.89954, saving model to best_model_1280.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.89954
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.89954
Epoch 10/100
Epoch 10: val_accuracy improved from 0.89954 to 0.90639, saving model to best_model_1280.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.90639
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.90639
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.90639
Epoch 14/100
Epoch 14: val_accuracy did not improve from 0.90639
Epoch 15/10

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26027
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26027
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26027 to 0.64155, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.64155 to 0.73744, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.73744 to 0.90639, saving model to best_model_1280.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.90639
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90639
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90639
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.90639
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.90639
Epoch 12/100
Epoch 12: val_accuracy improved from 0.90639 to 0.91553, saving model to best_model_1280.h5
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.91553
Epoch 14/100
Epoch 14: val_accuracy improved from 0.91553 to 0.91781, saving

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26941
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26941
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26941 to 0.31963, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.31963 to 0.85845, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.85845 to 0.90183, saving model to best_model_1280.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.90183
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90183
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90183
Epoch 10/100
Epoch 10: val_accuracy improved from 0.90183 to 0.92694, saving model to best_model_1280.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.92694
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.92694
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.92694
Epoch 14/100
Epoch 14: val_accuracy did not improve from 0.92694
Epoch 15/10

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.23744
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.23744
Epoch 4/100
Epoch 4: val_accuracy improved from 0.23744 to 0.71918, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy did not improve from 0.71918
Epoch 6/100
Epoch 6: val_accuracy improved from 0.71918 to 0.89954, saving model to best_model_1280.h5
Epoch 7/100
Epoch 7: val_accuracy improved from 0.89954 to 0.90868, saving model to best_model_1280.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90868
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90868
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.90868
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.90868
Epoch 12/100
Epoch 12: val_accuracy improved from 0.90868 to 0.91096, saving model to best_model_1280.h5
Epoch 13/100
Epoch 13: val_accuracy improved from 0.91096 to 0.91781, saving model to best_model_1280.h5
Epoch 14/100
Epoch 14: val_accuracy 

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26941
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26941
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26941 to 0.52055, saving model to best_model_1280.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.52055 to 0.86986, saving model to best_model_1280.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.86986 to 0.90411, saving model to best_model_1280.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.90411
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90411
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90411
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.90411
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.90411
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.90411
Epoch 13/100
Epoch 13: val_accuracy improved from 0.90411 to 0.90639, saving model to best_model_1280.h5
Epoch 14/100
Epoch 14: val_accuracy did not improve from 0.90639
Epoch 15/10

In [3]:
# Display the results for each fold with mean and standard deviation
print("Accuracy: Mean =", np.mean(ACC_collection), ", Std =", np.std(ACC_collection))
print("Balanced Accuracy: Mean =", np.mean(BACC_collection), ", Std =", np.std(BACC_collection))
print("Sensitivity (Sn): Mean =", np.mean(Sn_collection), ", Std =", np.std(Sn_collection))
print("Specificity (Sp): Mean =", np.mean(Sp_collection), ", Std =", np.std(Sp_collection))
print("MCC: Mean =", np.mean(MCC_collection), ", Std =", np.std(MCC_collection))
print("AUC: Mean =", np.mean(AUC_collection), ", Std =", np.std(AUC_collection))

Accuracy: Mean = 0.9281310783120625 , Std = 0.009373552754328912
Balanced Accuracy: Mean = 0.8985817327633564 , Std = 0.015307497729815803
Sensitivity (Sn): Mean = 0.8377980352766551 , Std = 0.03415118325399747
Specificity (Sp): Mean = 0.9593654302500578 , Std = 0.013860292931974307
MCC: Mean = 0.809294926429111 , Std = 0.027758012403472886
AUC: Mean = 0.9502147088294292 , Std = 0.012468456811801523


In [30]:
# Evaluate on the test dataset
predicted_probas_test = saved_model.predict(X_test, batch_size=32)
best_threshold_test, best_mcc_test = optimize_threshold(y_test, predicted_probas_test)
predicted_classes_test = (predicted_probas_test > best_threshold_test).astype(int)

# Calculate metrics for the test dataset with optimized threshold
accuracy_test = accuracy_score(y_test, predicted_classes_test)
sensitivity_test = recall_score(y_test, predicted_classes_test)  # Sensitivity (Recall)
TN_test, FP_test, FN_test, TP_test = confusion_matrix(y_test, predicted_classes_test).ravel()
specificity_test = TN_test / (TN_test + FP_test)  # Corrected Specificity calculation
MCC_test = matthews_corrcoef(y_test, predicted_classes_test)
auc_test = roc_auc_score(y_test, predicted_classes_test)

# Compute the correct balanced accuracy
balanced_accuracy_test = (sensitivity_test + specificity_test) / 2

# Print the adjusted results for the test dataset
print("\nOptimized Test Dataset Results:")
print(f"Accuracy (ACC): {accuracy_test}")
print(f"Balanced Accuracy (BACC): {balanced_accuracy_test}")
print(f"Sensitivity (Sn): {sensitivity_test}")
print(f"Specificity (Sp): {specificity_test}")
print(f"MCC: {MCC_test}")
print(f"AUC: {auc_test}")
print(f"True Positives (TP): {TP_test}")
print(f"False Positives (FP): {FP_test}")
print(f"True Negatives (TN): {TN_test}")
print(f"False Negatives (FN): {FN_test}")

# Print the total positive and total negative
total_positive = np.sum(y_test)
total_negative = len(y_test) - total_positive
print(f"Total Positive: {total_positive}")
print(f"Total Negative: {total_negative}")


Optimized Test Dataset Results:
Accuracy (ACC): 0.9288321167883211
Balanced Accuracy (BACC): 0.8929570280206153
Sensitivity (Sn): 0.8201438848920863
Specificity (Sp): 0.9657701711491442
MCC: 0.8081918754774996
AUC: 0.8929570280206153
True Positives (TP): 228
False Positives (FP): 28
True Negatives (TN): 790
False Negatives (FN): 50
Total Positive: 278
Total Negative: 818


In [31]:
# Evaluate on the external dataset (KELM)
dataset_external = pd.read_csv('kelm_dataset.csv', na_filter=False)
X_external_data_name = 'kelm_dataset_esm2_t33_650M_UR50D_unified_1280_dimension.csv'
X_external_data = pd.read_csv(X_external_data_name, header=0, index_col=0, delimiter=',')
X_external = np.array(X_external_data)
y_external = np.array(dataset_external['label'])

# Normalize the external dataset
X_external_normalized = scaler.transform(X_external)

# Predict probabilities for external dataset
predicted_probas_ext = saved_model.predict(X_external_normalized, batch_size=32)
best_threshold_ext, best_mcc_ext = optimize_threshold(y_external, predicted_probas_ext)
predicted_classes_ext = (predicted_probas_ext > best_threshold_ext).astype(int)

# Calculate metrics for the external dataset with optimized threshold
accuracy_ext = accuracy_score(y_external, predicted_classes_ext)
sensitivity_ext = recall_score(y_external, predicted_classes_ext)  # Sensitivity (Recall)
TN_ext, FP_ext, FN_ext, TP_ext = confusion_matrix(y_external, predicted_classes_ext).ravel()
specificity_ext = TN_ext / (TN_ext + FP_ext)  # Corrected Specificity calculation
MCC_ext = matthews_corrcoef(y_external, predicted_classes_ext)
auc_ext = roc_auc_score(y_external, predicted_classes_ext)

# Compute the correct balanced accuracy
balanced_accuracy_ext = (sensitivity_ext + specificity_ext) / 2

# Print the adjusted results for the external dataset
print("\nOptimized External Dataset (KELM) Results:")
print(f"Accuracy (ACC): {accuracy_ext}")
print(f"Balanced Accuracy (BACC): {balanced_accuracy_ext}")
print(f"Sensitivity (Sn): {sensitivity_ext}")
print(f"Specificity (Sp): {specificity_ext}")
print(f"MCC: {MCC_ext}")
print(f"AUC: {auc_ext}")
print(f"True Positives (TP): {TP_ext}")
print(f"False Positives (FP): {FP_ext}")
print(f"True Negatives (TN): {TN_ext}")
print(f"False Negatives (FN): {FN_ext}")

# Print the total positive and total negative
total_positive_ext = np.sum(y_external)
total_negative_ext = len(y_external) - total_positive_ext
print(f"Total Positive: {total_positive_ext}")
print(f"Total Negative: {total_negative_ext}")


Optimized External Dataset (KELM) Results:
Accuracy (ACC): 0.8958333333333334
Balanced Accuracy (BACC): 0.8958333333333333
Sensitivity (Sn): 0.84375
Specificity (Sp): 0.9479166666666666
MCC: 0.7959970056457051
AUC: 0.8958333333333333
True Positives (TP): 81
False Positives (FP): 5
True Negatives (TN): 91
False Negatives (FN): 15
Total Positive: 96
Total Negative: 96
