# MODEL IMPLEMENTATION
This notebook presents the implementation, training and evaluation of a machine learning model (a feed-forward neural network).

Based on the results previously obtained with the XGBoost model, we use only one data configuration: the dataset extended with the new features, without SMOTE oversampling.


In [None]:
# Mount Google Drive (Colab only) so the project folder under MyDrive is readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    classification_report,
    precision_score,
    recall_score,
    accuracy_score,
    f1_score,
    balanced_accuracy_score,
    roc_auc_score,
    matthews_corrcoef,
    confusion_matrix,
    hamming_loss)
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.callbacks import EarlyStopping
from keras.metrics import AUC, Precision, Recall, F1Score
from tensorflow.keras.models import load_model, clone_model
import joblib

In [None]:
# Never truncate columns when displaying the (wide) feature frames.
pd.options.display.max_columns = None

In [None]:
# Folder holding the train/val/test TSV files. Defaults to the shared Drive
# folder; set the DATA_FOLDER environment variable to run from another location.
DATA_FOLDER = os.environ.get(
    "DATA_FOLDER", "/content/drive/MyDrive/Structural Bioinfo PROJECT/datasets"
)
# Label columns; one independent binary classifier is trained per entry.
TARGETS = ['HBOND', 'IONIC', 'PICATION', 'PIHBOND', 'PIPISTACK', 'SSBOND', 'VDW']
EXPERIMENT_NAME = "new_no_smote"
TRAIN_FILE = "train_set_new.tsv"
VAL_FILE = "val_set_new.tsv"
TEST_FILE = "test_set_new.tsv"
# Models, confusion-matrix plots and metrics of this experiment are written here.
OUTPUT_FOLDER = os.path.join(DATA_FOLDER, f"results_{EXPERIMENT_NAME}")
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Data Loading

In [None]:
def load_tsv_as_df(folder, filename):
    """Read a tab-separated file from `folder` into a DataFrame.

    The '.tsv' extension is appended to `filename` when it is missing, so both
    'train_set_new' and 'train_set_new.tsv' resolve to the same file.
    """
    tsv_name = filename if filename.endswith('.tsv') else f"{filename}.tsv"
    return pd.read_csv(os.path.join(folder, tsv_name), sep='\t')

In [None]:
def load_dataset(folder, train_file, val_file, test_file, target_cols):
    """Load the three splits and separate features from labels.

    Every split file is read from `folder`. Its `target_cols` become the label
    frame and the remaining columns the feature frame, where the `same_chain`
    column is cast to int.

    Returns:
        X_train, y_train, X_val, y_val, X_test, y_test
    """
    def _features_and_labels(filename):
        # Split one TSV into (features with int same_chain, label columns).
        frame = load_tsv_as_df(folder, filename)
        features = frame.drop(columns=target_cols)
        features['same_chain'] = features['same_chain'].astype(int)
        return features, frame[target_cols]

    X_train, y_train = _features_and_labels(train_file)
    X_val, y_val = _features_and_labels(val_file)
    X_test, y_test = _features_and_labels(test_file)
    return X_train, y_train, X_val, y_val, X_test, y_test

In [None]:
# Load features and labels for the predefined train / validation / test splits.
X_train, y_train, X_val, y_val, X_test, y_test = load_dataset(
    folder=DATA_FOLDER,
    train_file=TRAIN_FILE,
    val_file=VAL_FILE,
    test_file=TEST_FILE,
    target_cols=TARGETS,
)

In [None]:
# Continuous features to standardise. All other feature columns (e.g. the
# int-encoded `same_chain` flag) are passed through unchanged.
continuous_cols = [
    's_resi', 's_rsa', 's_phi', 's_psi', 's_a1', 's_a2', 's_a3', 's_a4', 's_a5',
    't_resi', 't_rsa', 't_phi', 't_psi', 't_a1', 't_a2', 't_a3', 't_a4', 't_a5',
    'delta_rsa', 'delta_atchley_1', 'delta_atchley_2', 'delta_atchley_3', 'delta_atchley_4', 'delta_atchley_5',
    'ca_distance', 's_centroid_x', 's_centroid_y', 't_centroid_x', 't_centroid_y'
]
missing_cols = sorted(set(continuous_cols) - set(X_train.columns))
assert not missing_cols, f"Continuous columns missing from the feature set: {missing_cols}"

# Scale copies so the raw feature frames stay untouched.
scaler = StandardScaler()
X_train_scaled = X_train.copy()
X_val_scaled = X_val.copy()
X_test_scaled = X_test.copy()

# Fit the scaler on training data only, so no validation/test statistics leak in.
X_train_scaled[continuous_cols] = scaler.fit_transform(X_train[continuous_cols])
X_val_scaled[continuous_cols] = scaler.transform(X_val[continuous_cols])
X_test_scaled[continuous_cols] = scaler.transform(X_test[continuous_cols])

# Persist the fitted scaler with this experiment's outputs so the saved models
# can be reused with exactly the same preprocessing.
scaler_filename = os.path.join(OUTPUT_FOLDER, f"{EXPERIMENT_NAME}_nn_scaler.pkl")
joblib.dump(scaler, scaler_filename)
print(f"Scaler saved to: {scaler_filename}")

In [None]:
def make_dataset(X, y, batch_size=64, shuffle=False):
    """Wrap feature/label frames in a batched, prefetched `tf.data.Dataset`.

    Args:
        X, y: pandas objects (read through `.values`), cast to float32.
        batch_size: number of rows per batch.
        shuffle: if True, shuffle with a buffer covering the whole of `X`.
    """
    features = X.values.astype(np.float32)
    labels = y.values.astype(np.float32)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(X))
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
def compute_class_weights(y_df):
    """Return scikit-learn 'balanced' class weights for every label column.

    Each column of `y_df` is treated as an independent binary (0/1) target. The
    result maps column name -> {0: negative-class weight, 1: positive-class weight},
    the dict form passed to Keras' `class_weight` argument.
    """
    binary_classes = np.array([0, 1])
    per_label = {}
    for label in y_df.columns:
        neg_weight, pos_weight = compute_class_weight(
            'balanced', classes=binary_classes, y=y_df[label]
        )
        per_label[label] = {0: neg_weight, 1: pos_weight}
    return per_label

In [None]:
# Per-label "balanced" weights, derived from the training labels only.
class_weights = compute_class_weights(y_df=y_train)

# Neural Network
This network has three hidden layers, which makes it deeper than a simple perceptron but not a “deep neural network” in the classical sense. We use a *One-vs-All* approach: a separate binary classifier is trained for each label. As for the XGBoost model, class weights are applied during training to counteract class imbalance.

In [None]:
def make_model(input_dim, learning_rate=0.0001, patience=5, threshold=0.5):
    """Build and compile a single-label binary classifier.

    Architecture: three Dense -> BatchNormalization -> LeakyReLU hidden blocks
    (256, 128, 64 units, with dropout 0.3 and 0.2 after the last two) followed
    by one sigmoid output unit.

    Args:
        input_dim: number of input features.
        learning_rate: Adam learning rate (default 1e-4).
        patience: epochs without `val_loss` improvement before early stopping.
        threshold: probability cut-off used by the F1/precision/recall metrics
            to binarise predictions.

    Returns:
        (model, early_stopping): the compiled Keras model and an EarlyStopping
        callback that monitors `val_loss` and restores the best weights.
    """
    model = models.Sequential([
        layers.Input(shape=(input_dim,)),

        layers.Dense(256),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        layers.Dense(128),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        layers.Dense(64),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Dropout(0.2),

        # Single sigmoid unit: predicted probability that the label is positive.
        layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='binary_crossentropy',
                  metrics=['accuracy',
                           keras.metrics.AUC(name='auc'),
                           # F1Score's default threshold=None binarises by argmax over the
                           # output units; with a single unit that always predicts "positive",
                           # so an explicit threshold is required for a meaningful F1.
                           keras.metrics.F1Score(threshold=threshold, name='f1_score'),
                           keras.metrics.Precision(thresholds=threshold, name='precision'),
                           keras.metrics.Recall(thresholds=threshold, name='recall')
                  ])

    # Stop when validation loss stops improving and roll back to the best epoch.
    early_stopping = EarlyStopping(monitor='val_loss',
                                   patience=patience,
                                   restore_best_weights=True)
    return model, early_stopping

# Performance Evaluation Metrics
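In this section, each per-label classifier is trained and then evaluated on the test set; its predicted probabilities are binarised with a threshold of 0.5. The following metrics are computed for each label:
* **Accuracy** (both standard and balanced);

* **Precision** (macro version, i.e. the unweighted mean over the negative and the positive class);

* **Recall** (macro version);

* **F1 score** (macro version);

* **MCC** (Matthews Correlation Coefficient);

* **ROC AUC score** (computed from the predicted probabilities; for a single binary label, macro averaging has no effect);

* **Confusion matrix** (also saved as a PNG heatmap).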

In [None]:
# One-vs-All training: for every label in TARGETS an independent binary
# classifier is trained, saved, and evaluated on the test split.

SEED = 42                 # re-applied before each label so every model is reproducible
DECISION_THRESHOLD = 0.5  # probability cut-off used to binarise sigmoid outputs
BATCH_SIZE = 64
MAX_EPOCHS = 20


def _binary_metrics(y_true, y_prob, threshold=DECISION_THRESHOLD):
    """Return (metrics dict, 2x2 confusion matrix) for one binary label.

    Args:
        y_true: 1-D array of 0/1 ground-truth labels.
        y_prob: 1-D array of predicted positive-class probabilities.
        threshold: predictions with y_prob >= threshold count as positive.

    Precision, recall and F1 are macro-averaged over the negative and positive
    class; ROC AUC is computed from the raw probabilities.
    """
    y_pred = (y_prob >= threshold).astype(int)
    # labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='macro', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='macro', zero_division=0),
        'f1_score': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'roc_auc': roc_auc_score(y_true, y_prob),
        'mcc': matthews_corrcoef(y_true, y_pred),
        'confusion_matrix': cm.tolist(),
    }
    return metrics, cm


def _save_confusion_matrix(cm, target, path):
    """Render `cm` as an annotated heatmap titled with `target` and save it to `path`."""
    fig, ax = plt.subplots(figsize=(4, 3))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
    ax.set_title(f"{target} Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)  # one figure per label: release it instead of accumulating


results = {}
for target in TARGETS:
    print(f"\nTraining model for label: {target}")
    # Seeds Python, NumPy and TensorFlow (weight init, dropout, dataset shuffling).
    tf.keras.utils.set_random_seed(SEED)

    model, early_stopping = make_model(X_train_scaled.shape[1])

    # Batching/shuffling is done by the tf.data pipelines, so `fit` receives no
    # batch_size/shuffle arguments.
    train_ds = make_dataset(X_train_scaled, y_train[[target]], batch_size=BATCH_SIZE, shuffle=True)
    val_ds = make_dataset(X_val_scaled, y_val[[target]], batch_size=BATCH_SIZE)
    test_ds = make_dataset(X_test_scaled, y_test[[target]], batch_size=BATCH_SIZE)

    # Train with class weights
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=MAX_EPOCHS,
        class_weight=class_weights[target],
        callbacks=[early_stopping],
    )

    # Save model
    model.save(os.path.join(OUTPUT_FOLDER, f"model_{target}.h5"))

    # return_dict=True addresses results by metric name instead of relying on
    # the positional order of the returned list.
    test_logs = model.evaluate(test_ds, return_dict=True)
    print(f"Test Loss: {test_logs['loss']:.4f}")
    print(f"Test Accuracy: {test_logs['accuracy']:.4f}")
    print(f"Test Precision: {test_logs['precision']:.4f}")
    print(f"test AUC: {test_logs['auc']:.4f}")
    print(f"Test Recall: {test_logs['recall']:.4f}")
    print(f"Test F1 Score: {test_logs['f1_score']:.4f}")

    # test_ds is not shuffled, so one predict pass lines up with y_test's row order.
    y_prob = model.predict(test_ds, verbose=0).ravel()
    y_true = y_test[target].to_numpy().astype(int)

    metrics, cm = _binary_metrics(y_true, y_prob)
    results[target] = metrics

    print(f"Metrics for {target}:")
    print(f"Accuracy: {metrics['accuracy']:.4f}, Balanced Acc: {metrics['balanced_accuracy']:.4f}, "
          f"F1: {metrics['f1_score']:.4f}, Precision: {metrics['precision']:.4f}, "
          f"Recall: {metrics['recall']:.4f}, AUC: {metrics['roc_auc']:.4f}, MCC: {metrics['mcc']:.4f}")

    _save_confusion_matrix(cm, target, os.path.join(OUTPUT_FOLDER, f"nn_cm_{target}.png"))

# Save metrics (one row per label)
pd.DataFrame(results).T.to_csv(os.path.join(OUTPUT_FOLDER, f"{EXPERIMENT_NAME}_metrics_nn.csv"))


Training model for label: HBOND
Epoch 1/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 4ms/step - accuracy: 0.6904 - auc: 0.7485 - f1_score: 0.8434 - loss: 0.5901 - precision: 0.8508 - recall: 0.6973 - val_accuracy: 0.7869 - val_auc: 0.8755 - val_f1_score: 0.8421 - val_loss: 0.4512 - val_precision: 0.9093 - val_recall: 0.7854
Epoch 2/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 4ms/step - accuracy: 0.7788 - auc: 0.8617 - f1_score: 0.8425 - loss: 0.4637 - precision: 0.9021 - recall: 0.7809 - val_accuracy: 0.7964 - val_auc: 0.8864 - val_f1_score: 0.8421 - val_loss: 0.4251 - val_precision: 0.9143 - val_recall: 0.7945
Epoch 3/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 4ms/step - accuracy: 0.7873 - auc: 0.8746 - f1_score: 0.8431 - loss: 0.4419 - precision: 0.9096 - recall: 0.7862 - val_accuracy: 0.8051 - val_auc: 0.8906 - val_f1_score: 0.8421 - val_loss: 0.4103 - val_precision: 0.9122 - val_recall: 0.809



[1m4252/4252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 3ms/step - accuracy: 0.8297 - auc: 0.9155 - f1_score: 0.8436 - loss: 0.3603 - precision: 0.9264 - recall: 0.8327
Test Loss: 0.3620
Test Accuracy: 0.8297
Test Precision: 0.9244
test AUC: 0.9147
Test Recall: 0.8330
Test F1 Score: 0.8400
Metrics for HBOND:
Accuracy: 0.8297, Balanced Acc: 0.8271, F1: 0.8016, Precision: 0.7882, Recall: 0.8271, AUC: 0.9147, MCC: 0.6140

Training model for label: IONIC
Epoch 1/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 5ms/step - accuracy: 0.8713 - auc: 0.9757 - f1_score: 0.0482 - loss: 0.2082 - precision: 0.1887 - recall: 0.9900 - val_accuracy: 0.9492 - val_auc: 0.9879 - val_f1_score: 0.0466 - val_loss: 0.1530 - val_precision: 0.3194 - val_recall: 0.9994
Epoch 2/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 4ms/step - accuracy: 0.9491 - auc: 0.9866 - f1_score: 0.0478 - loss: 0.0987 - precision: 0.3246 - recall: 0.9978 - val_accuracy



[1m4252/4252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 3ms/step - accuracy: 0.9525 - auc: 0.9892 - f1_score: 0.0450 - loss: 0.1136 - precision: 0.3262 - recall: 0.9967
Test Loss: 0.1180
Test Accuracy: 0.9509
Test Precision: 0.3231
test AUC: 0.9889
Test Recall: 0.9961
Test F1 Score: 0.0460
Metrics for IONIC:
Accuracy: 0.9509, Balanced Acc: 0.9729, F1: 0.7311, Precision: 0.6615, Recall: 0.9729, AUC: 0.9889, MCC: 0.5527

Training model for label: PICATION
Epoch 1/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 4ms/step - accuracy: 0.8392 - auc: 0.9680 - f1_score: 0.0119 - loss: 0.2375 - precision: 0.0461 - recall: 0.9818 - val_accuracy: 0.9741 - val_auc: 0.9924 - val_f1_score: 0.0118 - val_loss: 0.1034 - val_precision: 0.1872 - val_recall: 1.0000
Epoch 2/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 4ms/step - accuracy: 0.9741 - auc: 0.9915 - f1_score: 0.0122 - loss: 0.0655 - precision: 0.1916 - recall: 0.9972 - val_accur



[1m4252/4252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 3ms/step - accuracy: 0.9746 - auc: 0.9926 - f1_score: 0.0126 - loss: 0.0797 - precision: 0.1981 - recall: 0.9918
Test Loss: 0.0798
Test Accuracy: 0.9747
Test Precision: 0.2006
test AUC: 0.9936
Test Recall: 0.9937
Test F1 Score: 0.0127
Metrics for PICATION:
Accuracy: 0.9747, Balanced Acc: 0.9841, F1: 0.6605, Precision: 0.6003, Recall: 0.9841, AUC: 0.9936, MCC: 0.4407

Training model for label: PIHBOND
Epoch 1/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 5ms/step - accuracy: 0.6733 - auc: 0.8715 - f1_score: 0.0025 - loss: 0.4672 - precision: 0.0037 - recall: 0.9102 - val_accuracy: 0.8371 - val_auc: 0.9316 - val_f1_score: 0.0022 - val_loss: 0.3852 - val_precision: 0.0061 - val_recall: 0.9146
Epoch 2/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 4ms/step - accuracy: 0.8455 - auc: 0.9435 - f1_score: 0.0025 - loss: 0.2959 - precision: 0.0075 - recall: 0.9299 - val_acc



[1m4252/4252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 3ms/step - accuracy: 0.9189 - auc: 0.9597 - f1_score: 0.0026 - loss: 0.1866 - precision: 0.0134 - recall: 0.8576
Test Loss: 0.1901
Test Accuracy: 0.9171
Test Precision: 0.0131
test AUC: 0.9563
Test Recall: 0.8543
Test F1 Score: 0.0026
Metrics for PIHBOND:
Accuracy: 0.9171, Balanced Acc: 0.8857, F1: 0.4912, Precision: 0.5065, Recall: 0.8857, AUC: 0.9565, MCC: 0.0998

Training model for label: PIPISTACK
Epoch 1/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 5ms/step - accuracy: 0.9105 - auc: 0.9840 - f1_score: 0.0509 - loss: 0.1462 - precision: 0.3008 - recall: 0.9867 - val_accuracy: 0.9904 - val_auc: 0.9969 - val_f1_score: 0.0520 - val_loss: 0.0469 - val_precision: 0.7346 - val_recall: 1.0000
Epoch 2/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 5ms/step - accuracy: 0.9904 - auc: 0.9965 - f1_score: 0.0506 - loss: 0.0291 - precision: 0.7289 - recall: 0.9999 - val_ac



[1m4252/4252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 3ms/step - accuracy: 0.9909 - auc: 0.9975 - f1_score: 0.0529 - loss: 0.0376 - precision: 0.7419 - recall: 0.9996
Test Loss: 0.0389
Test Accuracy: 0.9906
Test Precision: 0.7463
test AUC: 0.9974
Test Recall: 0.9995
Test F1 Score: 0.0550
Metrics for PIPISTACK:
Accuracy: 0.9906, Balanced Acc: 0.9949, F1: 0.9248, Precision: 0.8731, Recall: 0.9949, AUC: 0.9976, MCC: 0.8595

Training model for label: SSBOND
Epoch 1/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 5ms/step - accuracy: 0.8401 - auc: 0.9817 - f1_score: 0.0030 - loss: 0.1907 - precision: 0.0153 - recall: 0.9806 - val_accuracy: 0.9995 - val_auc: 0.9998 - val_f1_score: 0.0031 - val_loss: 0.0059 - val_precision: 0.7532 - val_recall: 1.0000
Epoch 2/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 5ms/step - accuracy: 0.9992 - auc: 0.9995 - f1_score: 0.0030 - loss: 0.0047 - precision: 0.6452 - recall: 0.9999 - val_acc



[1m4252/4252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 3ms/step - accuracy: 0.9993 - auc: 0.9986 - f1_score: 0.0023 - loss: 0.0039 - precision: 0.6282 - recall: 0.9988
Test Loss: 0.0043
Test Accuracy: 0.9993
Test Precision: 0.6054
test AUC: 0.9997
Test Recall: 1.0000
Test F1 Score: 0.0022
Metrics for SSBOND:
Accuracy: 0.9993, Balanced Acc: 0.9996, F1: 0.8769, Precision: 0.8027, Recall: 0.9996, AUC: 0.9998, MCC: 0.7778

Training model for label: VDW
Epoch 1/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 5ms/step - accuracy: 0.5830 - auc: 0.6159 - f1_score: 0.6726 - loss: 0.6826 - precision: 0.5907 - recall: 0.5857 - val_accuracy: 0.6562 - val_auc: 0.7266 - val_f1_score: 0.6736 - val_loss: 0.6064 - val_precision: 0.6789 - val_recall: 0.6130
Epoch 2/20
[1m13995/13995[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 4ms/step - accuracy: 0.6490 - auc: 0.7123 - f1_score: 0.6725 - loss: 0.6180 - precision: 0.6686 - recall: 0.6090 - val_accuracy:



[1m4252/4252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 3ms/step - accuracy: 0.6811 - auc: 0.7576 - f1_score: 0.6738 - loss: 0.5747 - precision: 0.7120 - recall: 0.6253
Test Loss: 0.5735
Test Accuracy: 0.6816
Test Precision: 0.7132
test AUC: 0.7589
Test Recall: 0.6269
Test F1 Score: 0.6748
Metrics for VDW:
Accuracy: 0.6816, Balanced Acc: 0.6827, F1: 0.6810, Precision: 0.6846, Recall: 0.6827, AUC: 0.7589, MCC: 0.3673
