In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, recall_score, matthews_corrcoef, roc_auc_score, confusion_matrix
from tensorflow.keras.layers import Input, Dense, Conv1D, MaxPooling1D, Flatten, Dropout, BatchNormalization
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping

# Load the dataset
dataset = pd.read_excel('Final_non_redundant_sequences.xlsx', na_filter=False)
X_data_name = 'prot_t5_xl_bfd_per_protein_embeddings.csv'
X_data = pd.read_csv(X_data_name, header=0, index_col=0, delimiter=',')
X = np.array(X_data)
y = np.array(dataset['label'])

# Split dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)

# Normalize the data
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

def build_model(input_shape):
    input = Input(input_shape)
    x = Conv1D(64, 5, strides=1, padding='same')(input)
    x = BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = MaxPooling1D(2, padding='same')(x)
    x = Dropout(0.25)(x)
    
    x = Conv1D(128, 5, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = MaxPooling1D(2, padding='same')(x)
    x = Dropout(0.25)(x)
    
    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(1, activation='sigmoid')(x)
    
    model = Model(inputs=input, outputs=x)
    return model

def step_decay(epoch):
    initial_lr = 0.01
    drop = 0.5
    epochs_drop = 10
    lr = initial_lr * np.power(drop, np.floor((1 + epoch) / epochs_drop))
    return lr

def train_model(X_train, y_train, X_test, y_test):
    input_shape = (1024, 1)
    model = build_model(input_shape)
    
    # Optimizer
    adam = tf.keras.optimizers.Adam(learning_rate=0.001)
    
    model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
    
    lrate = LearningRateScheduler(step_decay)
    early_stop = EarlyStopping(monitor='val_accuracy', patience=20, verbose=1, restore_best_weights=True)
    mc = ModelCheckpoint('best_model_1024-bfd.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
    callbacks_list = [lrate, early_stop, mc]
    
    class_weight = {0: 1, 1: 2}  # Adjust the weights as needed
    
    model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=100, callbacks=callbacks_list, batch_size=32, class_weight=class_weight)
    
    return model

# Train the model
trained_model = train_model(X_train, y_train, X_test, y_test)

# Load the best model
saved_model = load_model('best_model_1024-bfd.h5')

# Function to optimize threshold based on MCC
def optimize_threshold(y_true, y_pred_probas):
    thresholds = np.arange(0.1, 1.0, 0.05)
    best_mcc = -1
    best_threshold = 0.5
    
    for threshold in thresholds:
        y_pred = (y_pred_probas > threshold).astype(int)
        mcc = matthews_corrcoef(y_true, y_pred)
        
        if mcc > best_mcc:
            best_mcc = mcc
            best_threshold = threshold
    
    return best_threshold, best_mcc




Epoch 1/100


Epoch 1: val_accuracy improved from -inf to 0.25365, saving model to best_model_1024-bfd.h5


  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.25365
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.25365
Epoch 4/100
Epoch 4: val_accuracy improved from 0.25365 to 0.52099, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.52099 to 0.87500, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy did not improve from 0.87500
Epoch 7/100
Epoch 7: val_accuracy improved from 0.87500 to 0.89507, saving model to best_model_1024-bfd.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.89507
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.89507
Epoch 10/100
Epoch 10: val_accuracy improved from 0.89507 to 0.89781, saving model to best_model_1024-bfd.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.89781
Epoch 12/100
Epoch 12: val_accuracy improved from 0.89781 to 0.90328, saving model to best_model_1024-bfd.h5
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.90328
Epoch 14/100
Epo

In [2]:
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score, confusion_matrix
import math

k = 10
kf = KFold(n_splits=k, shuffle=True, random_state=1)

# Convert back to DataFrames if needed (assuming X_train is still a numpy array)
X_train = pd.DataFrame(X_train)
y_train = pd.DataFrame(y_train)

# Result collection lists
ACC_collection = []
BACC_collection = []
Sn_collection = []
Sp_collection = []
MCC_collection = []
AUC_collection = []

# Function to train the model for each fold
def ESM_CNN(X_train_CV, y_train_CV, X_valid_CV, y_valid_CV):
    # Train the model with the training fold
    model = train_model(X_train_CV, y_train_CV, X_valid_CV, y_valid_CV)
    return model

# Cross-validation loop
for train_index, test_index in kf.split(y_train):
    X_train_CV, X_valid_CV = X_train.iloc[train_index, :], X_train.iloc[test_index, :]
    y_train_CV, y_valid_CV = y_train.iloc[train_index], y_train.iloc[test_index]

    # Train the model for this fold
    model = ESM_CNN(X_train_CV, y_train_CV, X_valid_CV, y_valid_CV)
    
    # Load the best model
    saved_model = load_model('best_model_1024-bfd.h5')
    
    # Predict probabilities
    predicted_probabilities = saved_model.predict(X_valid_CV, batch_size=1)
    
    # Convert probabilities to class predictions
    predicted_class = (predicted_probabilities > 0.5).astype(int)
    
    # Compute confusion matrix
    y_true = y_valid_CV.values
    TN, FP, FN, TP = confusion_matrix(y_true, predicted_class).ravel()
    
    # Compute metrics
    ACC = (TP + TN) / (TP + TN + FP + FN)
    Sn = TP / (TP + FN)
    Sp = TN / (TN + FP)
    MCC = (TP * TN - FP * FN) / math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
    BACC = 0.5 * (Sn + Sp)
    AUC = roc_auc_score(y_true, predicted_probabilities)
    
    # Append metrics to collection lists
    ACC_collection.append(ACC)
    Sn_collection.append(Sn)
    Sp_collection.append(Sp)
    MCC_collection.append(MCC)
    BACC_collection.append(BACC)
    AUC_collection.append(AUC)

# Display the results for each fold
print("Average Accuracy:", np.mean(ACC_collection))
print("Average Balanced Accuracy:", np.mean(BACC_collection))
print("Average Sensitivity (Sn):", np.mean(Sn_collection))
print("Average Specificity (Sp):", np.mean(Sp_collection))
print("Average MCC:", np.mean(MCC_collection))
print("Average AUC:", np.mean(AUC_collection))

Epoch 1/100
Epoch 1: val_accuracy improved from -inf to 0.24601, saving model to best_model_1024-bfd.h5


  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.24601
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.24601
Epoch 4/100
Epoch 4: val_accuracy improved from 0.24601 to 0.55809, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.55809 to 0.91344, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.91344 to 0.92027, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.92027
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.92027
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.92027
Epoch 10/100
Epoch 10: val_accuracy improved from 0.92027 to 0.92938, saving model to best_model_1024-bfd.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.92938
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.92938
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.92938
Epoch 14/100
Epoch 14: val_accuracy did not improve from 0.9

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26424
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26424
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26424 to 0.26879, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.26879 to 0.87244, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.87244 to 0.89522, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.89522
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.89522
Epoch 9/100
Epoch 9: val_accuracy improved from 0.89522 to 0.90433, saving model to best_model_1024-bfd.h5
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.90433
Epoch 11/100
Epoch 11: val_accuracy improved from 0.90433 to 0.91344, saving model to best_model_1024-bfd.h5
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91344
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.91344
Epoch 14/100
Epo

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26424
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26424
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26424 to 0.63098, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.63098 to 0.74943, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.74943 to 0.89294, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.89294
Epoch 8/100
Epoch 8: val_accuracy improved from 0.89294 to 0.89522, saving model to best_model_1024-bfd.h5
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.89522
Epoch 10/100
Epoch 10: val_accuracy improved from 0.89522 to 0.90433, saving model to best_model_1024-bfd.h5
Epoch 11/100
Epoch 11: val_accuracy improved from 0.90433 to 0.90888, saving model to best_model_1024-bfd.h5
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.90888
Epoch 13/100
Epoch 13: val_accuracy d

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.24429
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.24429
Epoch 4/100
Epoch 4: val_accuracy improved from 0.24429 to 0.26027, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.26027 to 0.66438, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.66438 to 0.89726, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy improved from 0.89726 to 0.90639, saving model to best_model_1024-bfd.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90639
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90639
Epoch 10/100
Epoch 10: val_accuracy improved from 0.90639 to 0.90868, saving model to best_model_1024-bfd.h5
Epoch 11/100
Epoch 11: val_accuracy improved from 0.90868 to 0.91553, saving model to best_model_1024-bfd.h5
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91553
Epoch 13/100
Epoch 13: val_accuracy i

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.23973
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.23973
Epoch 4/100
Epoch 4: val_accuracy did not improve from 0.23973
Epoch 5/100
Epoch 5: val_accuracy improved from 0.23973 to 0.70091, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.70091 to 0.89269, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.89269
Epoch 8/100
Epoch 8: val_accuracy improved from 0.89269 to 0.90639, saving model to best_model_1024-bfd.h5
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90639
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.90639
Epoch 11/100
Epoch 11: val_accuracy improved from 0.90639 to 0.91096, saving model to best_model_1024-bfd.h5
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91096
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.91096
Epoch 14/100
Epoch 14: val_accuracy did not improve from 0.9

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26256
Epoch 3/100
Epoch 3: val_accuracy improved from 0.26256 to 0.30137, saving model to best_model_1024-bfd.h5
Epoch 4/100
Epoch 4: val_accuracy improved from 0.30137 to 0.63470, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.63470 to 0.78767, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.78767 to 0.88813, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy improved from 0.88813 to 0.90183, saving model to best_model_1024-bfd.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90183
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90183
Epoch 10/100
Epoch 10: val_accuracy improved from 0.90183 to 0.90639, saving model to best_model_1024-bfd.h5
Epoch 11/100
Epoch 11: val_accuracy improved from 0.90639 to 0.91553, saving model to best_model_1024-bfd.h5
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26027
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26027
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26027 to 0.50685, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.50685 to 0.89954, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy did not improve from 0.89954
Epoch 7/100
Epoch 7: val_accuracy improved from 0.89954 to 0.90411, saving model to best_model_1024-bfd.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90411
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90411
Epoch 10/100
Epoch 10: val_accuracy improved from 0.90411 to 0.91781, saving model to best_model_1024-bfd.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.91781
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91781
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.91781
Epoch 14/100
Epoch 14: val_accuracy did not improve from 0.9

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26941
Epoch 3/100
Epoch 3: val_accuracy improved from 0.26941 to 0.30822, saving model to best_model_1024-bfd.h5
Epoch 4/100
Epoch 4: val_accuracy improved from 0.30822 to 0.71233, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.71233 to 0.89954, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.89954 to 0.91324, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy did not improve from 0.91324
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.91324
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.91324
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.91324
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.91324
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91324
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.91324
Epoch 14/100
Epoch 14: val_accuracy did not improve from 0.9

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.23744
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.23744
Epoch 4/100
Epoch 4: val_accuracy improved from 0.23744 to 0.41096, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.41096 to 0.84932, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.84932 to 0.89954, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy improved from 0.89954 to 0.90411, saving model to best_model_1024-bfd.h5
Epoch 8/100
Epoch 8: val_accuracy did not improve from 0.90411
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90411
Epoch 10/100
Epoch 10: val_accuracy did not improve from 0.90411
Epoch 11/100
Epoch 11: val_accuracy improved from 0.90411 to 0.91324, saving model to best_model_1024-bfd.h5
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91324
Epoch 13/100
Epoch 13: val_accuracy did not improve from 0.91324
Epoch 14/100
Epo

  saving_api.save_model(


Epoch 2/100
Epoch 2: val_accuracy did not improve from 0.26941
Epoch 3/100
Epoch 3: val_accuracy did not improve from 0.26941
Epoch 4/100
Epoch 4: val_accuracy improved from 0.26941 to 0.28082, saving model to best_model_1024-bfd.h5
Epoch 5/100
Epoch 5: val_accuracy improved from 0.28082 to 0.82877, saving model to best_model_1024-bfd.h5
Epoch 6/100
Epoch 6: val_accuracy improved from 0.82877 to 0.88128, saving model to best_model_1024-bfd.h5
Epoch 7/100
Epoch 7: val_accuracy improved from 0.88128 to 0.89498, saving model to best_model_1024-bfd.h5
Epoch 8/100
Epoch 8: val_accuracy improved from 0.89498 to 0.90411, saving model to best_model_1024-bfd.h5
Epoch 9/100
Epoch 9: val_accuracy did not improve from 0.90411
Epoch 10/100
Epoch 10: val_accuracy improved from 0.90411 to 0.91553, saving model to best_model_1024-bfd.h5
Epoch 11/100
Epoch 11: val_accuracy did not improve from 0.91553
Epoch 12/100
Epoch 12: val_accuracy did not improve from 0.91553
Epoch 13/100
Epoch 13: val_accuracy d

In [3]:
# Display the results for each fold with mean and standard deviation
print("Accuracy: Mean =", np.mean(ACC_collection), ", Std =", np.std(ACC_collection))
print("Balanced Accuracy: Mean =", np.mean(BACC_collection), ", Std =", np.std(BACC_collection))
print("Sensitivity (Sn): Mean =", np.mean(Sn_collection), ", Std =", np.std(Sn_collection))
print("Specificity (Sp): Mean =", np.mean(Sp_collection), ", Std =", np.std(Sp_collection))
print("MCC: Mean =", np.mean(MCC_collection), ", Std =", np.std(MCC_collection))
print("AUC: Mean =", np.mean(AUC_collection), ", Std =", np.std(AUC_collection))

Accuracy: Mean = 0.923109807470278 , Std = 0.005981142361201247
Balanced Accuracy: Mean = 0.8851181794004332 , Std = 0.014595670392030769
Sensitivity (Sn): Mean = 0.8072993785644561 , Std = 0.03444317129778136
Specificity (Sp): Mean = 0.9629369802364105 , Std = 0.007911726260188342
MCC: Mean = 0.793753922970523 , Std = 0.01815333403849451
AUC: Mean = 0.9428888796269981 , Std = 0.01349407004350602


In [2]:
# Evaluate on the test dataset
predicted_probas_test = saved_model.predict(X_test, batch_size=32)
best_threshold_test, best_mcc_test = optimize_threshold(y_test, predicted_probas_test)
predicted_classes_test = (predicted_probas_test > best_threshold_test).astype(int)

# Calculate metrics for the test dataset with optimized threshold
accuracy_test = accuracy_score(y_test, predicted_classes_test)
sensitivity_test = recall_score(y_test, predicted_classes_test)  # Sensitivity (Recall)
TN_test, FP_test, FN_test, TP_test = confusion_matrix(y_test, predicted_classes_test).ravel()
specificity_test = TN_test / (TN_test + FP_test)  # Corrected Specificity calculation
MCC_test = matthews_corrcoef(y_test, predicted_classes_test)
auc_test = roc_auc_score(y_test, predicted_classes_test)

# Compute the correct balanced accuracy
balanced_accuracy_test = (sensitivity_test + specificity_test) / 2

# Print the adjusted results for the test dataset
print("\nOptimized Test Dataset Results:")
print(f"Accuracy (ACC): {accuracy_test}")
print(f"Balanced Accuracy (BACC): {balanced_accuracy_test}")
print(f"Sensitivity (Sn): {sensitivity_test}")
print(f"Specificity (Sp): {specificity_test}")
print(f"MCC: {MCC_test}")
print(f"AUC: {auc_test}")
print(f"True Positives (TP): {TP_test}")
print(f"False Positives (FP): {FP_test}")
print(f"True Negatives (TN): {TN_test}")
print(f"False Negatives (FN): {FN_test}")

# Print the total positive and total negative
total_positive = np.sum(y_test)
total_negative = len(y_test) - total_positive
print(f"Total Positive: {total_positive}")
print(f"Total Negative: {total_negative}")


Optimized Test Dataset Results:
Accuracy (ACC): 0.9206204379562044
Balanced Accuracy (BACC): 0.8910177481486694
Sensitivity (Sn): 0.8309352517985612
Specificity (Sp): 0.9511002444987775
MCC: 0.7887037324016605
AUC: 0.8910177481486694
True Positives (TP): 231
False Positives (FP): 40
True Negatives (TN): 778
False Negatives (FN): 47
Total Positive: 278
Total Negative: 818


In [3]:
# Evaluate on the external dataset (KELM)
dataset_external = pd.read_csv('kelm_dataset.csv', na_filter=False)
X_external_data_name = 'prot_bfd_per_protein_embeddings_kelm_dataset.csv'
X_external_data = pd.read_csv(X_external_data_name, header=0, index_col=0, delimiter=',')
X_external = np.array(X_external_data)
y_external = np.array(dataset_external['label'])

# Normalize the external dataset
X_external_normalized = scaler.transform(X_external)

# Predict probabilities for external dataset
predicted_probas_ext = saved_model.predict(X_external_normalized, batch_size=32)
best_threshold_ext, best_mcc_ext = optimize_threshold(y_external, predicted_probas_ext)
predicted_classes_ext = (predicted_probas_ext > best_threshold_ext).astype(int)

# Calculate metrics for the external dataset with optimized threshold
accuracy_ext = accuracy_score(y_external, predicted_classes_ext)
sensitivity_ext = recall_score(y_external, predicted_classes_ext)  # Sensitivity (Recall)
TN_ext, FP_ext, FN_ext, TP_ext = confusion_matrix(y_external, predicted_classes_ext).ravel()
specificity_ext = TN_ext / (TN_ext + FP_ext)  # Corrected Specificity calculation
MCC_ext = matthews_corrcoef(y_external, predicted_classes_ext)
auc_ext = roc_auc_score(y_external, predicted_classes_ext)

# Compute the correct balanced accuracy
balanced_accuracy_ext = (sensitivity_ext + specificity_ext) / 2

# Print the adjusted results for the external dataset
print("\nOptimized External Dataset (KELM) Results:")
print(f"Accuracy (ACC): {accuracy_ext}")
print(f"Balanced Accuracy (BACC): {balanced_accuracy_ext}")
print(f"Sensitivity (Sn): {sensitivity_ext}")
print(f"Specificity (Sp): {specificity_ext}")
print(f"MCC: {MCC_ext}")
print(f"AUC: {auc_ext}")
print(f"True Positives (TP): {TP_ext}")
print(f"False Positives (FP): {FP_ext}")
print(f"True Negatives (TN): {TN_ext}")
print(f"False Negatives (FN): {FN_ext}")

# Print the total positive and total negative
total_positive_ext = np.sum(y_external)
total_negative_ext = len(y_external) - total_positive_ext
print(f"Total Positive: {total_positive_ext}")
print(f"Total Negative: {total_negative_ext}")


Optimized External Dataset (KELM) Results:
Accuracy (ACC): 0.9010416666666666
Balanced Accuracy (BACC): 0.9010416666666666
Sensitivity (Sn): 0.8854166666666666
Specificity (Sp): 0.9166666666666666
MCC: 0.802475262666927
AUC: 0.9010416666666665
True Positives (TP): 85
False Positives (FP): 8
True Negatives (TN): 88
False Negatives (FN): 11
Total Positive: 96
Total Negative: 96
