In [1]:
import os
os.environ["KERAS_BACKEND"] ='tensorflow'
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
from tqdm.auto import tqdm
tqdm.pandas()
import keras 
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
from keras import layers, Model, Input
from keras_hub.src.models.llama.llama_decoder import LlamaTransformerDecoder
from keras_hub.src.models.llama.llama_layernorm import LlamaLayerNorm

  from .autonotebook import tqdm as notebook_tqdm
2026-01-04 06:46:20.206045: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767530780.220261   63792 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767530780.225205   63792 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767530780.239214   63792 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767530780.239225   63792 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767530780.239227   63792

# Data Loading

In [2]:
import h5py
# Location of the HDF5 file read below; every top-level key becomes an entry of `loaded_data`.
file_path = "../DeepInterAware/data/Yeast/Yeast-ProtT5-Full.h5"

# Read every dataset fully into memory so later cells can index by key
# without keeping the file handle open.
loaded_data = {}
with h5py.File(file_path, 'r') as hf:
    loaded_data.update((seq, hf[seq][:]) for seq in hf.keys())

In [3]:
# Tab-separated table of protein pairs; the columns used later are the two
# sequence columns and the binary `Interaction` label.
PAIRS_URL = (
    "https://raw.githubusercontent.com/Fengithub/symLMF-PPI/refs/heads/master/"
    "datasets/S.cerevisiae-benchmark/pros_AB.txt"
)
df = pd.read_csv(PAIRS_URL, sep='\t')
df.head()

Unnamed: 0,Protein_A_id,Protein_B_id,Protein_A_sequence,Protein_B_sequence,Protein_A_idx,Protein_B_idx,Interaction
0,P16649,P14922,MTASVSNTQNKLNELLDAIRQEFLQVSQEANTYRLQNQKDYDFKMN...,MNPGGEQTIMEQPAQQQQQQQQQQQQQQQQAAVPQQPLDPLTQSTA...,0,577,1
1,P07269,P22035,MMEEFSYDHDFNTHFATDLDYLQHDQQQQQQQQHDQQHNQQQQPQP...,MSNISTKDIRKSKPKRGSGFDLLEVTESLGYQTHRKNGRNSWSKDD...,1,1598,1
2,P33418,P50278,MLERIQQLVNAVNDPRSDVATKRQAIELLNGIKSSENALEIFISLV...,MTTTVPKVFAFHEFAGVAEAVADHVIHAQNSALKKGKVSRSTQMSG...,2,1599,1
3,P27705,P14922,MQSPYPMTQVSNVDDGSLLKESKSKSKVAAKSEAPRPHACPICHRA...,MNPGGEQTIMEQPAQQQQQQQQQQQQQQQQAAVPQQPLDPLTQSTA...,3,577,1
4,P05453,P05453,MSDSNQGNNQQNYQQYSQNGNQQQGNNRYQGYQAYNAQAQPAGGYY...,MSDSNQGNNQQNYQQYSQNGNQQQGNNRYQGYQAYNAQAQPAGGYY...,4,4,1


In [4]:
# Longest sequence (in characters) on each side of the pair, as a 2-tuple (A, B).
tuple(df[col].str.len().max() for col in ('Protein_A_sequence', 'Protein_B_sequence'))

(4092, 4910)

In [5]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import Sequence

def process_sequence_tf(x_emb, max_len=512, pad_value=0.0):
    """Force a 2-D embedding matrix to exactly ``max_len`` rows.

    Args:
        x_emb: 2-D array-like; only axis 0 is resized, axis 1 is left as is.
        max_len: Number of rows in the returned embedding.
        pad_value: Fill value for rows appended to inputs shorter than ``max_len``.

    Returns:
        ``(embedding, mask)`` where ``embedding`` is a tensor with ``max_len``
        rows and ``mask`` is a float32 vector of length ``max_len`` holding 1.0
        for rows taken from ``x_emb`` and 0.0 for appended rows.
    """
    n_rows = x_emb.shape[0]

    # Over budget: keep the leading rows; every kept position is real.
    if n_rows > max_len:
        return tf.convert_to_tensor(x_emb[:max_len]), tf.ones([max_len], dtype=tf.float32)

    # At or under budget: append rows at the end (possibly zero of them).
    n_missing = max_len - n_rows
    padded = tf.pad(x_emb, [[0, n_missing], [0, 0]], constant_values=pad_value)
    real_positions = tf.ones([n_rows], dtype=tf.float32)
    mask = tf.pad(real_positions, [[0, n_missing]], constant_values=0.0)
    return padded, mask

# -----------------------------
# Keras Sequence Loader
# -----------------------------
class DataSequenceLoader(Sequence):
    """Keras ``Sequence`` that yields fixed-length embedding batches for protein pairs.

    The values of the ``Protein_A_sequence`` / ``Protein_B_sequence`` columns
    are used as keys into an embedding store; each looked-up array is reshaped
    to ``(-1, emb_dim)`` and padded/truncated to ``max_len`` rows with
    ``process_sequence_tf``.

    Args:
        df: DataFrame with ``Protein_A_sequence``, ``Protein_B_sequence`` and
            ``Interaction`` columns.
        batch_size: Number of pairs per batch (the last batch may be smaller).
        shuffle: Shuffle the sample order at construction and after each epoch.
        max_len: Number of rows every embedding is padded/truncated to.
        pad_value: Fill value used for padded rows.
        embeddings: Optional mapping from column value to embedding array.
            When ``None`` the module-level ``loaded_data`` dict is used.
        emb_dim: Width each stored embedding is reshaped to.
        **kwargs: Forwarded to the Keras base class (e.g. ``workers``).

    Each item is ``((x1, x2, mask1, mask2), labels)`` as NumPy arrays.
    """

    def __init__(self, df, batch_size=32, shuffle=True, max_len=512, pad_value=0.0,
                 embeddings=None, emb_dim=1024, **kwargs):
        # The Keras base class must be initialised; skipping this is what
        # produces the `_warn_if_super_not_called` warning during fit/predict.
        super().__init__(**kwargs)
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        # NOTE: despite the `_emb` suffix these hold the lookup keys, not embeddings.
        self.x1_emb = df["Protein_A_sequence"].values
        self.x2_emb = df["Protein_B_sequence"].values
        self.labels = df["Interaction"].values
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.max_len = max_len
        self.pad_value = pad_value
        self.embeddings = embeddings
        self.emb_dim = emb_dim
        self.indices = np.arange(len(df))
        if shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        """Number of batches per epoch (ceil division, so no sample is dropped)."""
        return int(np.ceil(len(self.indices) / self.batch_size))

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __getitem__(self, idx):
        """Build batch ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``. This
                lets plain ``for batch in loader`` iteration terminate instead
                of failing inside ``np.stack`` on an empty batch.
        """
        n_batches = len(self)
        if idx < 0:
            idx += n_batches
        if not 0 <= idx < n_batches:
            raise IndexError(f"batch index {idx} out of range for {n_batches} batches")

        # Resolve the store lazily so the global default is read at access time.
        store = loaded_data if self.embeddings is None else self.embeddings
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]

        x1_list, x2_list, m1_list, m2_list = [], [], [], []
        labels_list = []

        for i in batch_idx:
            emb1 = store[self.x1_emb[i]].reshape(-1, self.emb_dim)
            emb2 = store[self.x2_emb[i]].reshape(-1, self.emb_dim)

            # Pad/truncate to max_len and build the matching validity masks.
            x1_pad, mask1 = process_sequence_tf(emb1, max_len=self.max_len, pad_value=self.pad_value)
            x2_pad, mask2 = process_sequence_tf(emb2, max_len=self.max_len, pad_value=self.pad_value)

            x1_list.append(x1_pad.numpy())
            x2_list.append(x2_pad.numpy())
            m1_list.append(mask1.numpy())
            m2_list.append(mask2.numpy())
            labels_list.append(self.labels[i])

        # Stack per-sample arrays along a new leading batch axis.
        x1_batch = np.stack(x1_list, axis=0)
        x2_batch = np.stack(x2_list, axis=0)
        m1_batch = np.stack(m1_list, axis=0)
        m2_batch = np.stack(m2_list, axis=0)
        labels_batch = np.array(labels_list)

        return (x1_batch, x2_batch, m1_batch, m2_batch), labels_batch


In [6]:
# loader  = DataSequenceLoader(df,batch_size=4,shuffle =True)
# for x in loader:
#     break

In [7]:
# -------------------------------------------------------------------
# Hybrid Pooling Layer (Max + Avg)
# -------------------------------------------------------------------
class HybridPooling(layers.Layer):
    """Collapse axis 1 with both max and mean, then join the two results.

    For an input whose last axis has size ``d`` the output's last axis has
    size ``2 * d`` (max features first, mean features second). The reduction
    runs over every position on axis 1; no mask is taken into account.
    """

    def call(self, x):
        pooled = [reduce(x, axis=1) for reduce in (keras.ops.max, keras.ops.mean)]
        return keras.ops.concatenate(pooled, axis=-1)


# -------------------------------------------------------------------
# Conv Block (Conv1D → ReLU → Dropout → MaxPool)
# -------------------------------------------------------------------
def conv_block(x, filters =  100, kernel_sz =20, stride =10, dropout = 0.5):
    """Apply Conv1D -> ReLU -> Dropout -> MaxPooling1D to ``x``.

    Args:
        x: Input tensor fed to the ``Conv1D`` layer.
        filters: Number of convolution filters.
        kernel_sz: Convolution kernel size.
        stride: Convolution stride.
        dropout: Dropout rate applied after the activation.

    Returns:
        The tensor produced by the final pooling layer (pool size 3, stride 1,
        ``"same"`` padding).
    """
    stages = (
        layers.Conv1D(filters, kernel_sz, strides=stride),
        layers.ReLU(),
        layers.Dropout(dropout),
        layers.MaxPooling1D(pool_size=3, strides=1, padding="same"),
    )
    for stage in stages:
        x = stage(x)
    return x


# -------------------------------------------------------------------
# Cross-attention (Query=X1, Key=X2, Value=X2)
# -------------------------------------------------------------------
def cross_attention_block(query, key, value, mask, num_heads =4, key_dim =32):
    """Multi-head cross-attention of ``query`` over ``key`` / ``value``.

    Args:
        query: Tensor providing the attention queries.
        key: Tensor providing the attention keys.
        value: Tensor providing the attention values.
        mask: Padding mask for the key/value sequence; passed as both
            ``key_mask`` and ``value_mask``.
        num_heads: Number of attention heads.
        key_dim: Per-head size of the query/key projections.

    Returns:
        The attention output tensor from ``layers.MultiHeadAttention``.
    """
    attention = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=key_dim,
        dropout=0.0,
        kernel_initializer="glorot_uniform",
        use_bias=True,
        # flash_attention=None,  # if GPU supports
    )
    # MultiHeadAttention's positional order is (query, value, key), so key and
    # value are bound by name to keep each tensor in the role this function's
    # signature promises.
    return attention(query, value=value, key=key, key_mask=mask, value_mask=mask)


# -------------------------------------------------------------------
# LLaMA Self-Attention Block
# -------------------------------------------------------------------
def llama_self_attention(x, mask, hidden_dim =100, num_heads = 4):
    """Run ``x`` through one freshly created ``LlamaTransformerDecoder`` layer.

    Args:
        x: Input sequence tensor.
        mask: Padding mask forwarded as ``decoder_padding_mask``.
        hidden_dim: The layer's ``intermediate_dim`` is set to ``hidden_dim * 4``.
        num_heads: Accepted but currently unused; the head counts are fixed at
            8 query heads and 2 key/value heads.

    Returns:
        The decoder layer's output tensor.
    """
    # NOTE(review): this is a decoder layer; confirm whether it also applies
    # causal masking on top of the padding mask.
    decoder_config = {
        "intermediate_dim": hidden_dim * 4,
        "num_query_heads": 8,
        "num_key_value_heads": 2,
        "dropout": 0.0,
        "layer_norm_epsilon": 1e-5,
        "activation": "silu",
    }
    decoder_layer = LlamaTransformerDecoder(**decoder_config)
    return decoder_layer(x, decoder_padding_mask=mask)

# -------------------------------------------------------------------
# Functional API Model (PyTorch → Keras Conversion)
# -------------------------------------------------------------------
def build_model(
    input_dim=640,
    conv_out=100,
    kernel_sz=20,
    stride=10,
    heads=4,
    d_dim=32,
    drop_pool=0.5,
    drop_linear=0.3
):
    """Assemble the two-branch pair classifier as a functional Keras model.

    Each branch projects its 1024-wide input through a bias-free Dense stem
    and batch normalisation, then the branch is processed by a LLaMA decoder
    layer and by cross-attention over the other branch. The two results are
    summed, passed through dropout, pooled with ``HybridPooling``, and the
    pooled branches are concatenated and fed to a two-layer MLP with a sigmoid
    output.

    Args:
        input_dim: Currently unused (the inputs are fixed at 1024 features).
        conv_out: Passed as ``hidden_dim`` to ``llama_self_attention``.
        kernel_sz: Currently unused.
        stride: Currently unused.
        heads: Number of heads for cross-attention (also forwarded to
            ``llama_self_attention``).
        d_dim: Per-head key dimension for cross-attention.
        drop_pool: Dropout rate applied before pooling.
        drop_linear: Dropout rate inside the MLP head.

    Returns:
        A ``Model`` with inputs ``[inp1, inp2, mask1, mask2]`` and a single
        sigmoid output.
    """
    # Variable-length embedding inputs and their padding masks.
    seq_a = Input((None, 1024))
    seq_b = Input((None, 1024))
    pad_a = Input((None, ))
    pad_b = Input((None, ))

    # Stem: linear projection to a shared width followed by batch norm.
    stem_width = 384
    proj_a = layers.Dense(stem_width, use_bias=False, name='stem1')(seq_a)
    proj_a = layers.BatchNormalization(momentum=0.95, name='bn1')(proj_a)

    proj_b = layers.Dense(stem_width, use_bias=False, name='stem2')(seq_b)
    proj_b = layers.BatchNormalization(momentum=0.95, name='bn2')(proj_b)

    # Self-attention path: one LLaMA decoder layer per branch.
    self_a = llama_self_attention(proj_a, pad_a, conv_out, heads)
    self_b = llama_self_attention(proj_b, pad_b, conv_out, heads)

    # Cross-attention path: each branch queries the other one.
    cross_a = cross_attention_block(proj_a, proj_b, proj_b, pad_b, heads, d_dim)
    cross_b = cross_attention_block(proj_b, proj_a, proj_a, pad_a, heads, d_dim)

    # Sum the two paths per branch, then regularise.
    fused_a = layers.Add()([self_a, cross_a])
    fused_b = layers.Add()([self_b, cross_b])

    fused_a = layers.Dropout(drop_pool)(fused_a)
    fused_b = layers.Dropout(drop_pool)(fused_b)

    # Collapse the sequence axis (max + mean) and join both branches.
    pooled_a = HybridPooling()(fused_a)
    pooled_b = HybridPooling()(fused_b)
    features = layers.Concatenate()([pooled_a, pooled_b])

    # MLP head: two Dense(256, relu) + Dropout stages, then the sigmoid unit.
    for _ in range(2):
        features = layers.Dense(256, activation="relu")(features)
        features = layers.Dropout(drop_linear)(features)
    probability = layers.Dense(1, activation="sigmoid")(features)

    return Model(inputs=[seq_a, seq_b, pad_a, pad_b], outputs=probability)




In [8]:
import os
import time
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, f1_score, recall_score, precision_score,
    roc_auc_score, average_precision_score, confusion_matrix
)
import wandb
from wandb.integration.keras import WandbMetricsLogger

In [9]:
# =========================================================
# DEFAULT CONFIG (Logged Automatically to W&B)
# =========================================================
# Wall-clock stamp taken once so every artefact of this run shares one suffix.
TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")

# Hyper-parameters and run metadata; the training cell forwards this dict to wandb.init.
CONFIG = dict(
    batch_size=4,
    n_splits=5,
    epochs=10,
    learning_rate=5e-5,
    optimizer="Adam",
    loss="binary_crossentropy",
    architecture="ProtT5",
    dataset="Yeast",
    task="Prot-Prot Classification",
    random_state=42,
    max_len=512,
)

# <dataset>-<architecture>-<timestamp> names both the W&B project and the output folder.
PROJECT_NAME = "-".join([CONFIG["dataset"], CONFIG["architecture"], TIMESTAMP])
OUT_PATH = os.path.join("weights", PROJECT_NAME)

# Create the per-run sub-folders for TensorBoard logs and weight files.
for subdir in ("logs", "weights"):
    os.makedirs(os.path.join(OUT_PATH, subdir), exist_ok=True)
OUT_PATH

'weights/Yeast-ProtT5-20260104-064702'

In [10]:
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, f1_score, recall_score, precision_score,
    roc_auc_score, matthews_corrcoef, confusion_matrix,
    average_precision_score
)
# =============================================================
#  K-FOLD CROSS VALIDATION SETTINGS
# =============================================================
# Both the fold count and the seed come from CONFIG, so the values logged to
# W&B are the ones actually used and the evaluation cell can rebuild the
# exact same splits.
skf = StratifiedKFold(
    n_splits=CONFIG["n_splits"],
    shuffle=True,
    random_state=CONFIG["random_state"],
)

all_metrics = []

# =============================================================
#  MAIN LOOP
# =============================================================
for fold, (train_idx, valid_idx) in enumerate(skf.split(df, df["Interaction"]), 1):
    print(f"\n==========================")
    print(f" Fold {fold} / {CONFIG['n_splits']}")
    print(f"==========================")

    # The context manager finishes the W&B run even when the fold raises,
    # so a crashed fold does not leave a dangling run behind.
    with wandb.init(
        project=PROJECT_NAME,
        name=f"fold_{fold}",
        group="KFold-CV",
        config={**CONFIG, "fold": fold},
        reinit=True,
    ) as run:
        train_df = df.iloc[train_idx].reset_index(drop=True)
        valid_df = df.iloc[valid_idx].reset_index(drop=True)

        # ---------------------------------------
        # Loaders
        # ---------------------------------------
        train_loader = DataSequenceLoader(
            train_df,
            batch_size=CONFIG["batch_size"],
            max_len=CONFIG["max_len"],
            shuffle=True,
        )
        # Validation order must stay fixed so predictions line up with valid_df.
        valid_loader = DataSequenceLoader(
            valid_df,
            batch_size=CONFIG["batch_size"],
            max_len=CONFIG["max_len"],
            shuffle=False,
        )

        # ---------------------------------------
        # Build a FRESH MODEL per fold
        # ---------------------------------------
        model = build_model()
        model.compile(
            optimizer=tf.keras.optimizers.Adam(CONFIG["learning_rate"]),
            loss=CONFIG["loss"],
            metrics=[
                "accuracy",
                tf.keras.metrics.AUC(name="auc"),
            ],
        )

        # -----------------------------------------------------
        # Callbacks
        # -----------------------------------------------------
        tb_callback = tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(OUT_PATH, "logs", f"fold_{fold}")
        )

        # Keeps only the weights of the epoch with the lowest validation loss.
        checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(
                OUT_PATH, "weights",
                f"weights_fold{fold}-best.weights.h5"
            ),
            save_weights_only=True,
            save_best_only=True,
            monitor="val_loss",
        )

        # ---------------------------------------
        # Train
        # ---------------------------------------
        history = model.fit(
            train_loader,
            validation_data=valid_loader,
            epochs=CONFIG["epochs"],
            callbacks=[
                tb_callback,
                checkpoint_cb,
                WandbMetricsLogger(log_freq="epoch"),
            ],
            verbose=1,
        )

        model.save_weights(
            os.path.join(
                OUT_PATH,
                "weights",
                f"weights_fold{fold}-last.weights.h5"
            )
        )

        # -----------------------------------------------------
        # Evaluation (uses the last-epoch weights still held by `model`)
        # -----------------------------------------------------
        y_pred_prob = model.predict(valid_loader).ravel()
        y_pred = (y_pred_prob > 0.5).astype(int)
        y_true = valid_df["Interaction"].values

        # Fixing the label set keeps the matrix 2x2 (and the unpacking valid)
        # even if the predictions contain a single class.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        metrics_dict = {
            "fold": fold,
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(y_true, y_pred),
            "recall": recall_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred),
            "roc_auc": roc_auc_score(y_true, y_pred_prob),
            "aupr": average_precision_score(y_true, y_pred_prob),
            "specificity": specificity
        }

        # Log fold metrics to W&B
        wandb.log(metrics_dict)

        print(metrics_dict)
        all_metrics.append(metrics_dict)

# =========================================================
# Save All Metrics
# =========================================================
metrics_df = pd.DataFrame(all_metrics)

# Average every metric column; `fold` is an identifier, so it is excluded from
# the mean and labelled "Average" instead. Casting `fold` to object first
# avoids writing a string into a numeric column and keeps fold ids as integers.
avg_record = {
    "fold": "Average",
    **metrics_df.drop(columns="fold").mean(numeric_only=True).to_dict(),
}
metrics_df = pd.concat(
    [metrics_df.astype({"fold": object}), pd.DataFrame([avg_record])],
    ignore_index=True
)

metrics_df.to_csv(
    os.path.join(
        OUT_PATH,
        f"{PROJECT_NAME}-kfold_classification_metrics.csv"
    ),
    index=False
)

print("\nAll fold metrics saved.")
metrics_df


 Fold 1 / 5


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhossainstudy7[0m ([33mhossainstudy7-freelancer[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


I0000 00:00:1767530826.439754   63792 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 79078 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:81:00.0, compute capability: 8.0


Epoch 1/10


  self._warn_if_super_not_called()
I0000 00:00:1767530834.268663   64309 service.cc:152] XLA service 0x7ffe4c006e40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1767530834.268703   64309 service.cc:160]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
I0000 00:00:1767530835.311133   64309 cuda_dnn.cc:529] Loaded cuDNN version 90500


[1m   1/2238[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m170:48:01[0m 275s/step - accuracy: 0.5000 - auc: 0.6667 - loss: 1.0392

I0000 00:00:1767531103.944725   64309 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 137ms/step - accuracy: 0.6243 - auc: 0.6662 - loss: 0.8456

  self._warn_if_super_not_called()


[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m681s[0m 182ms/step - accuracy: 0.7288 - auc: 0.8006 - loss: 0.5844 - val_accuracy: 0.8865 - val_auc: 0.9587 - val_loss: 0.2939
Epoch 2/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 20ms/step - accuracy: 0.8934 - auc: 0.9561 - loss: 0.2672 - val_accuracy: 0.9267 - val_auc: 0.9805 - val_loss: 0.1946
Epoch 3/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 20ms/step - accuracy: 0.9317 - auc: 0.9792 - loss: 0.1811 - val_accuracy: 0.9352 - val_auc: 0.9859 - val_loss: 0.1752
Epoch 4/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 20ms/step - accuracy: 0.9512 - auc: 0.9886 - loss: 0.1312 - val_accuracy: 0.9580 - val_auc: 0.9918 - val_loss: 0.1431
Epoch 5/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 20ms/step - accuracy: 0.9592 - auc: 0.9931 - loss: 0.1029 - val_accuracy: 0.9620 - val_auc: 0.9908 - val_loss: 0.1161
Epoch 6/10
[1m2238



[1m556/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 12ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step
{'fold': 1, 'accuracy': 0.9597855227882037, 'f1': 0.9595687331536388, 'recall': 0.9544235924932976, 'precision': 0.964769647696477, 'roc_auc': 0.9915729686517947, 'aupr': 0.9927496573463946, 'specificity': 0.9651474530831099}


0,1
accuracy,▁
aupr,▁
epoch/accuracy,▁▆▇▇▇█████
epoch/auc,▁▇▇███████
epoch/epoch,▁▂▃▃▄▅▆▆▇█
epoch/learning_rate,▁▁▁▁▁▁▁▁▁▁
epoch/loss,█▄▃▂▂▁▁▁▁▁
epoch/val_accuracy,▁▄▅▇█▇▇▇█▇
epoch/val_auc,▁▅▇████▇██
epoch/val_loss,█▄▄▃▂▂▁▂▁▁

0,1
accuracy,0.95979
aupr,0.99275
epoch/accuracy,0.97777
epoch/auc,0.99724
epoch/epoch,9
epoch/learning_rate,5e-05
epoch/loss,0.05763
epoch/val_accuracy,0.95979
epoch/val_auc,0.99125
epoch/val_loss,0.11029



 Fold 2 / 5


Epoch 1/10


  self._warn_if_super_not_called()


[1m2236/2238[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 21ms/step - accuracy: 0.6152 - auc: 0.6510 - loss: 0.9282



[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 27ms/step - accuracy: 0.7232 - auc: 0.7902 - loss: 0.6148 - val_accuracy: 0.8852 - val_auc: 0.9548 - val_loss: 0.3133
Epoch 2/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 18ms/step - accuracy: 0.8920 - auc: 0.9558 - loss: 0.2677 - val_accuracy: 0.9276 - val_auc: 0.9765 - val_loss: 0.2198
Epoch 3/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 18ms/step - accuracy: 0.9326 - auc: 0.9821 - loss: 0.1686 - val_accuracy: 0.9383 - val_auc: 0.9826 - val_loss: 0.1729
Epoch 4/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 20ms/step - accuracy: 0.9484 - auc: 0.9885 - loss: 0.1329 - val_accuracy: 0.9383 - val_auc: 0.9842 - val_loss: 0.1689
Epoch 5/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 18ms/step - accuracy: 0.9569 - auc: 0.9919 - loss: 0.1102 - val_accuracy: 0.9477 - val_auc: 0.9875 - val_loss: 0.1377
Epoch 6/10
[1m2238/2



[1m558/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 12ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step
{'fold': 2, 'accuracy': 0.9548704200178731, 'f1': 0.955011135857461, 'recall': 0.9579982126899017, 'precision': 0.9520426287744227, 'roc_auc': 0.9905635137973471, 'aupr': 0.9922743097297222, 'specificity': 0.9517426273458445}


0,1
accuracy,▁
aupr,▁
epoch/accuracy,▁▆▇▇▇█████
epoch/auc,▁▇▇███████
epoch/epoch,▁▂▃▃▄▅▆▆▇█
epoch/learning_rate,▁▁▁▁▁▁▁▁▁▁
epoch/loss,█▄▂▂▂▁▁▁▁▁
epoch/val_accuracy,▁▅▆▆▇██▇▇█
epoch/val_auc,▁▅▇▇▇█████
epoch/val_loss,█▅▃▃▂▁▁▂▂▁

0,1
accuracy,0.95487
aupr,0.99227
epoch/accuracy,0.98078
epoch/auc,0.99815
epoch/epoch,9
epoch/learning_rate,5e-05
epoch/loss,0.05099
epoch/val_accuracy,0.95487
epoch/val_auc,0.99004
epoch/val_loss,0.12253



 Fold 3 / 5


Epoch 1/10


  self._warn_if_super_not_called()


[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.6281 - auc: 0.6592 - loss: 1.0007



[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 28ms/step - accuracy: 0.7242 - auc: 0.7839 - loss: 0.6424 - val_accuracy: 0.8601 - val_auc: 0.9539 - val_loss: 0.3312
Epoch 2/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 20ms/step - accuracy: 0.8916 - auc: 0.9543 - loss: 0.2710 - val_accuracy: 0.9209 - val_auc: 0.9776 - val_loss: 0.2102
Epoch 3/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 18ms/step - accuracy: 0.9326 - auc: 0.9799 - loss: 0.1784 - val_accuracy: 0.9428 - val_auc: 0.9877 - val_loss: 0.1535
Epoch 4/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 19ms/step - accuracy: 0.9454 - auc: 0.9880 - loss: 0.1371 - val_accuracy: 0.9187 - val_auc: 0.9854 - val_loss: 0.1949
Epoch 5/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 19ms/step - accuracy: 0.9584 - auc: 0.9925 - loss: 0.1070 - val_accuracy: 0.9486 - val_auc: 0.9903 - val_loss: 0.1353
Epoch 6/10
[1m2238/2



[1m558/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 12ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step
{'fold': 3, 'accuracy': 0.9432529043789097, 'f1': 0.9415016121602948, 'recall': 0.9133154602323503, 'precision': 0.9714828897338403, 'roc_auc': 0.9859842304623767, 'aupr': 0.9880279273420506, 'specificity': 0.9731903485254692}


0,1
accuracy,▁
aupr,▁
epoch/accuracy,▁▆▇▇▇█████
epoch/auc,▁▇▇███████
epoch/epoch,▁▂▃▃▄▅▆▆▇█
epoch/learning_rate,▁▁▁▁▁▁▁▁▁▁
epoch/loss,█▄▃▂▂▁▁▁▁▁
epoch/val_accuracy,▁▆█▆█▇████
epoch/val_auc,▁▆█▇█████▇
epoch/val_loss,█▄▂▃▁▃▁▁▁▃

0,1
accuracy,0.94325
aupr,0.98803
epoch/accuracy,0.98145
epoch/auc,0.99795
epoch/epoch,9
epoch/learning_rate,5e-05
epoch/loss,0.04851
epoch/val_accuracy,0.94325
epoch/val_auc,0.98328
epoch/val_loss,0.17404



 Fold 4 / 5


Epoch 1/10


  self._warn_if_super_not_called()


[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 140ms/step - accuracy: 0.6233 - auc: 0.6640 - loss: 0.9032



[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m448s[0m 193ms/step - accuracy: 0.7248 - auc: 0.7976 - loss: 0.5952 - val_accuracy: 0.8932 - val_auc: 0.9572 - val_loss: 0.2907
Epoch 2/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 19ms/step - accuracy: 0.8948 - auc: 0.9576 - loss: 0.2610 - val_accuracy: 0.8713 - val_auc: 0.9829 - val_loss: 0.2737
Epoch 3/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 18ms/step - accuracy: 0.9350 - auc: 0.9821 - loss: 0.1690 - val_accuracy: 0.9383 - val_auc: 0.9846 - val_loss: 0.1608
Epoch 4/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 18ms/step - accuracy: 0.9527 - auc: 0.9886 - loss: 0.1309 - val_accuracy: 0.9410 - val_auc: 0.9820 - val_loss: 0.1861
Epoch 5/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 19ms/step - accuracy: 0.9584 - auc: 0.9916 - loss: 0.1118 - val_accuracy: 0.9535 - val_auc: 0.9896 - val_loss: 0.1350
Epoch 6/10
[1m2238



[1m556/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 12ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step
{'fold': 4, 'accuracy': 0.957532409476978, 'f1': 0.957149300857014, 'recall': 0.9481680071492404, 'precision': 0.9663023679417122, 'roc_auc': 0.9921609346448801, 'aupr': 0.9932681013860887, 'specificity': 0.9669051878354203}


0,1
accuracy,▁
aupr,▁
epoch/accuracy,▁▆▇▇▇█████
epoch/auc,▁▇▇███████
epoch/epoch,▁▂▃▃▄▅▆▆▇█
epoch/learning_rate,▁▁▁▁▁▁▁▁▁▁
epoch/loss,█▄▂▂▂▁▁▁▁▁
epoch/val_accuracy,▃▁▆▆▇▇▇▇██
epoch/val_auc,▁▆▆▆▇▇▇███
epoch/val_loss,█▇▃▄▂▂▃▁▁▁

0,1
accuracy,0.95753
aupr,0.99327
epoch/accuracy,0.97821
epoch/auc,0.99676
epoch/epoch,9
epoch/learning_rate,5e-05
epoch/loss,0.06218
epoch/val_accuracy,0.95753
epoch/val_auc,0.99218
epoch/val_loss,0.10852



 Fold 5 / 5


Epoch 1/10


  self._warn_if_super_not_called()


[1m2236/2238[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 21ms/step - accuracy: 0.6179 - auc: 0.6486 - loss: 1.0247



[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 27ms/step - accuracy: 0.7203 - auc: 0.7786 - loss: 0.6481 - val_accuracy: 0.8789 - val_auc: 0.9464 - val_loss: 0.3185
Epoch 2/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 20ms/step - accuracy: 0.8877 - auc: 0.9540 - loss: 0.2744 - val_accuracy: 0.9133 - val_auc: 0.9802 - val_loss: 0.2295
Epoch 3/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 20ms/step - accuracy: 0.9275 - auc: 0.9796 - loss: 0.1815 - val_accuracy: 0.9428 - val_auc: 0.9833 - val_loss: 0.1657
Epoch 4/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 20ms/step - accuracy: 0.9517 - auc: 0.9891 - loss: 0.1287 - val_accuracy: 0.9477 - val_auc: 0.9864 - val_loss: 0.1502
Epoch 5/10
[1m2238/2238[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 20ms/step - accuracy: 0.9590 - auc: 0.9918 - loss: 0.1094 - val_accuracy: 0.9437 - val_auc: 0.9897 - val_loss: 0.1455
Epoch 6/10
[1m2238/2



[1m556/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 12ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step
{'fold': 5, 'accuracy': 0.951721054984354, 'f1': 0.9505041246562786, 'recall': 0.9275491949910555, 'precision': 0.974624060150376, 'roc_auc': 0.9902673131677433, 'aupr': 0.9916782160780181, 'specificity': 0.9758713136729222}


0,1
accuracy,▁
aupr,▁
epoch/accuracy,▁▅▇▇▇█████
epoch/auc,▁▇▇███████
epoch/epoch,▁▂▃▃▄▅▆▆▇█
epoch/learning_rate,▁▁▁▁▁▁▁▁▁▁
epoch/loss,█▄▃▂▂▁▁▁▁▁
epoch/val_accuracy,▁▄▇▇▇▆█▇██
epoch/val_auc,▁▆▇▇██████
epoch/val_loss,█▅▃▂▂▂▁▂▁▁

0,1
accuracy,0.95172
aupr,0.99168
epoch/accuracy,0.98246
epoch/auc,0.99828
epoch/epoch,9
epoch/learning_rate,5e-05
epoch/loss,0.04626
epoch/val_accuracy,0.95172
epoch/val_auc,0.98924
epoch/val_loss,0.12901



All fold metrics saved.


  metrics_df.loc[metrics_df.index[-1], "fold"] = "Average"


Unnamed: 0,fold,accuracy,f1,recall,precision,roc_auc,aupr,specificity
0,1.0,0.959786,0.959569,0.954424,0.96477,0.991573,0.99275,0.965147
1,2.0,0.95487,0.955011,0.957998,0.952043,0.990564,0.992274,0.951743
2,3.0,0.943253,0.941502,0.913315,0.971483,0.985984,0.988028,0.97319
3,4.0,0.957532,0.957149,0.948168,0.966302,0.992161,0.993268,0.966905
4,5.0,0.951721,0.950504,0.927549,0.974624,0.990267,0.991678,0.975871
5,Average,0.953432,0.952747,0.940291,0.965844,0.99011,0.9916,0.966571


In [13]:
# =============================================================
#  EVALUATION (POST-TRAINING)
# =============================================================
from sklearn.metrics import (
    accuracy_score, f1_score, recall_score, precision_score,
    roc_auc_score, matthews_corrcoef,
    confusion_matrix, average_precision_score
)

all_metrics = []

# Must reproduce the training splits exactly, so the fold count and the seed
# are read from the same CONFIG entries.
skf = StratifiedKFold(
    n_splits=CONFIG['n_splits'],
    shuffle=True,
    random_state=CONFIG["random_state"],
)

for fold, (_, valid_idx) in enumerate(
        skf.split(df, df["Interaction"]), 1):

    print(f"\nEvaluating Fold {fold}")

    valid_df = df.iloc[valid_idx].reset_index(drop=True)

    # No shuffling: predictions must stay aligned with valid_df rows.
    valid_loader = DataSequenceLoader(
        valid_df,
        batch_size=CONFIG["batch_size"],
        max_len=CONFIG["max_len"],
        shuffle=False
    )

    # Rebuild model & load best weights
    model = build_model()
    model.load_weights(
        os.path.join(
            OUT_PATH, "weights",
            f"weights_fold{fold}-best.weights.h5"
        )
    )

    # Prediction
    y_prob = model.predict(valid_loader).ravel()
    y_pred = (y_prob >= 0.5).astype(int)
    y_true = valid_df["Interaction"].values

    # Fixing the label set keeps the matrix 2x2 (and the unpacking valid)
    # even if the predictions contain a single class.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    metrics_dict = {
        "fold": fold,
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),     # sensitivity
        "precision": precision_score(y_true, y_pred),
        "mcc": matthews_corrcoef(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_prob),
        "prauc": average_precision_score(y_true, y_prob),
        "specificity": tn / (tn + fp) if (tn + fp) > 0 else 0.0,
        "sensitivity": tp / (tp + fn) if (tp + fn) > 0 else 0.0
    }

    print(metrics_dict)
    all_metrics.append(metrics_dict)

metrics_df = pd.DataFrame(all_metrics)

# Average every metric column; `fold` is an identifier, so it is excluded from
# the mean and labelled "Average" instead. Building the row as a record avoids
# assigning a string into a float Series.
avg_record = {
    "fold": "Average",
    **metrics_df.drop(columns="fold").mean(numeric_only=True).to_dict(),
}
metrics_df = pd.concat(
    [metrics_df.astype({"fold": object}), pd.DataFrame([avg_record])],
    ignore_index=True
)

metrics_df.to_csv(
    os.path.join(
        OUT_PATH,
        f"{PROJECT_NAME}-kfold_evaluation_metrics.csv"
    ),
    index=False
)

print("\nEvaluation completed and saved.")
metrics_df



Evaluating Fold 1


  self._warn_if_super_not_called()


[1m556/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 11ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step
{'fold': 1, 'accuracy': 0.9673815907059875, 'f1': 0.9669234254644313, 'recall': 0.9535299374441466, 'precision': 0.9806985294117647, 'mcc': 0.9351220908566883, 'auc': 0.9929841290377196, 'prauc': 0.9940483459145044, 'specificity': 0.9812332439678284, 'sensitivity': 0.9535299374441466}

Evaluating Fold 2


  self._warn_if_super_not_called()


[1m556/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 11ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 17ms/step
{'fold': 2, 'accuracy': 0.9548704200178731, 'f1': 0.955011135857461, 'recall': 0.9579982126899017, 'precision': 0.9520426287744227, 'mcc': 0.909758640705754, 'auc': 0.9905635137973471, 'prauc': 0.9922743097297222, 'specificity': 0.9517426273458445, 'sensitivity': 0.9579982126899017}

Evaluating Fold 3


  self._warn_if_super_not_called()


[1m558/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 11ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 17ms/step
{'fold': 3, 'accuracy': 0.9481680071492404, 'f1': 0.9467889908256881, 'recall': 0.9222520107238605, 'precision': 0.9726672950047125, 'mcc': 0.8975424739759503, 'auc': 0.9897481234441896, 'prauc': 0.9912045097618505, 'specificity': 0.9740840035746202, 'sensitivity': 0.9222520107238605}

Evaluating Fold 4


  self._warn_if_super_not_called()


[1m555/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 13ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step
{'fold': 4, 'accuracy': 0.957532409476978, 'f1': 0.957149300857014, 'recall': 0.9481680071492404, 'precision': 0.9663023679417122, 'mcc': 0.9152268379237681, 'auc': 0.9921609346448801, 'prauc': 0.9932681013860887, 'specificity': 0.9669051878354203, 'sensitivity': 0.9481680071492404}

Evaluating Fold 5


  self._warn_if_super_not_called()


[1m556/560[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 11ms/step



[1m560/560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 17ms/step
{'fold': 5, 'accuracy': 0.9530621367903442, 'f1': 0.9533540648600622, 'recall': 0.9597495527728086, 'precision': 0.9470432480141218, 'mcc': 0.9062063109440174, 'auc': 0.9914455310053539, 'prauc': 0.9926356483461605, 'specificity': 0.9463806970509383, 'sensitivity': 0.9597495527728086}

Evaluation completed and saved.


  avg_row["fold"] = "Average"


Unnamed: 0,fold,accuracy,f1,recall,precision,mcc,auc,prauc,specificity,sensitivity
0,1,0.967382,0.966923,0.95353,0.980699,0.935122,0.992984,0.994048,0.981233,0.95353
1,2,0.95487,0.955011,0.957998,0.952043,0.909759,0.990564,0.992274,0.951743,0.957998
2,3,0.948168,0.946789,0.922252,0.972667,0.897542,0.989748,0.991205,0.974084,0.922252
3,4,0.957532,0.957149,0.948168,0.966302,0.915227,0.992161,0.993268,0.966905,0.948168
4,5,0.953062,0.953354,0.95975,0.947043,0.906206,0.991446,0.992636,0.946381,0.95975
5,Average,0.956203,0.955845,0.94834,0.963751,0.912771,0.99138,0.992686,0.964069,0.94834
