In [1]:
import os
os.environ["KERAS_BACKEND"] ='tensorflow'
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
from tqdm.auto import tqdm
tqdm.pandas()
import keras 
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
from keras import layers, Model, Input
from keras_hub.src.models.llama.llama_decoder import LlamaTransformerDecoder
from keras_hub.src.models.llama.llama_layernorm import LlamaLayerNorm

  from .autonotebook import tqdm as notebook_tqdm
2025-12-04 14:53:16.482746: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764881596.495679  118629 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764881596.499885  118629 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Data Loading

In [2]:
import h5py
DATA_PATH = '../RLEAAI/data/HIV/'
loaded_data = {}
with h5py.File(os.path.join(DATA_PATH,"HIV-RLEAAI-ProtT5-Full.h5"), 'r') as hf:
    for seq in hf.keys():
        loaded_data[seq] = hf[seq][:]

In [3]:
df = pd.read_excel("https://github.com/zhouyu9931/RLEAAI/raw/refs/heads/main/data/dataset_hiv.xlsx")
df.head()

Unnamed: 0,antibody_seq,virus_seq,label,split
0,QMKLMQSGGVMVRPGESATLSCVASGFDFSRNGFEWLRQGPGKGLQ...,MRVMGIRKNYQHLWREGILLLGILMICSAADNLWVTVYYGVPVWRE...,0,seen
1,QPQLQESGPGLVEASETLSLTCTVSGDSTGRCNYFWGWVRQPPGKG...,MRVRGIPRNWPQWWIWGILGFWMIIICRVVGNMWVTVYYGVPVWTD...,0,seen
2,QVQLLQSGAAVTKPGASVRVSCEASGYNIRDYFIHWWRQAPGQGLQ...,MRVMEIQRNCQHWWIWGILGFWMLMICNVRGWWVTVYYGVPVWKEA...,1,seen
3,QSQLQESGPRLVEASETLSLTCNVSGESTGACTYFWGWVRQAPGKG...,MRVKETQMNWPNLWKLGTLILGLVIICSASXNLWVTVYYGVPVWRD...,1,seen
4,QEQLVESGGGVVQPGGSLRLSCLASGFTFHKYGMHWVRQAPGKGLE...,MRVTGTQRNCQQWWIWIWIILGFWWMLMMCKGEKLWVTVYYGVPVW...,1,seen


In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import Sequence

def process_sequence_tf(x_emb, max_len=1024, pad_value=0.0):
    seq_len = x_emb.shape[0]
    
    if seq_len > max_len:
        # truncate
        x_emb = tf.convert_to_tensor(x_emb[:max_len])
        mask = tf.ones([max_len], dtype=tf.float32)
    else:
        # pad
        pad_len = max_len - seq_len
        paddings = [[0, pad_len], [0, 0]]
        x_emb = tf.pad(x_emb, paddings, constant_values=pad_value)
        mask = tf.pad(tf.ones([seq_len], dtype=tf.float32), [[0, pad_len]], constant_values=0.0)
    
    return x_emb, mask

# -----------------------------
# Keras Sequence Loader
# -----------------------------
class DataSequenceLoader(Sequence):
    def __init__(self, df, batch_size=32, shuffle=True, max_len=512, pad_value=0.0):
        self.x1_emb = df["antibody_seq"].values
        self.x2_emb = df["virus_seq"].values
        self.labels = df["label"].values
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.max_len = max_len
        self.pad_value = pad_value
        self.indices = np.arange(len(df))
        if shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        return int(np.ceil(len(self.indices) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __getitem__(self, idx):
        batch_idx = self.indices[idx*self.batch_size:(idx+1)*self.batch_size]

        x1_list, x2_list, m1_list, m2_list = [], [], [], []
        labels_list = []

        for i in batch_idx:
            emb1 = loaded_data[self.x1_emb[i]].reshape(-1, 1024)
            emb2 = loaded_data[self.x2_emb[i]].reshape(-1, 1024)

            # ---- process embeddings + mask ----
            x1_pad, mask1 = process_sequence_tf(emb1, max_len=self.max_len, pad_value=self.pad_value)
            x2_pad, mask2 = process_sequence_tf(emb2, max_len=self.max_len, pad_value=self.pad_value)

            x1_list.append(x1_pad.numpy())
            x2_list.append(x2_pad.numpy())
            m1_list.append(mask1.numpy())
            m2_list.append(mask2.numpy())
            labels_list.append(self.labels[i])

        # Convert lists to arrays for batch
        x1_batch = np.stack(x1_list, axis=0)
        x2_batch = np.stack(x2_list, axis=0)
        m1_batch = np.stack(m1_list, axis=0)
        m2_batch = np.stack(m2_list, axis=0)
        labels_batch = np.array(labels_list)

        return (x1_batch, x2_batch, m1_batch, m2_batch), labels_batch


In [5]:
# loader  = DataSequenceLoader(df,batch_size=4,shuffle =True)
# for x in loader:
#     break

In [6]:
# -------------------------------------------------------------------
# Hybrid Pooling Layer (Max + Avg)
# -------------------------------------------------------------------
class HybridPooling(layers.Layer):
    def call(self, x):
        max_pooled = keras.ops.max(x, axis=1)
        avg_pooled = keras.ops.mean(x, axis=1)
        return keras.ops.concatenate([max_pooled, avg_pooled], axis=-1)


# -------------------------------------------------------------------
# Conv Block (Conv1D → ReLU → Dropout → MaxPool)
# -------------------------------------------------------------------
def conv_block(x, filters =  100, kernel_sz =20, stride =10, dropout = 0.5):
    x = layers.Conv1D(filters, kernel_sz, strides=stride)(x)
    x = layers.ReLU()(x)
    x = layers.Dropout(dropout)(x)
    x = layers.MaxPooling1D(pool_size=3, strides=1, padding="same")(x)
    return x


# -------------------------------------------------------------------
# Cross-attention (Query=X1, Key=X2, Value=X2)
# -------------------------------------------------------------------
def cross_attention_block(query, key, value, mask, num_heads =4, key_dim =32):
    attn = layers.MultiHeadAttention(
        num_heads=num_heads, 
        key_dim=key_dim,
        dropout=0.0,
        kernel_initializer="glorot_uniform",
        use_bias=True,
        # flash_attention=None,  # if GPU supports
    )(query, key, value, key_mask=mask, value_mask=mask)
    return attn


# -------------------------------------------------------------------
# LLaMA Self-Attention Block
# -------------------------------------------------------------------
def llama_self_attention(x, mask, hidden_dim =100, num_heads = 4):
    # single LLaMA decoder layer
    llama = LlamaTransformerDecoder(
        intermediate_dim=hidden_dim * 4,
        # num_heads=num_heads,
        num_query_heads=8,
        num_key_value_heads=2,
        dropout=0.0,
        layer_norm_epsilon=1e-5,
        activation="silu"
    )
    return llama(x, decoder_padding_mask=mask)

# -------------------------------------------------------------------
# Functional API Model (PyTorch → Keras Conversion)
# -------------------------------------------------------------------
def build_model(
    input_dim=1024,
    conv_out=100,
    kernel_sz=20,
    stride=10,
    heads=4,
    d_dim=32,
    drop_pool=0.4,
    drop_linear=0.4
):
    # Inputs
    inp1 = Input((None, 1024))
    inp2 = Input((None, 1024))
    mask1 = Input((None, ))
    mask2 = Input((None, ))

    # ---------------------------------------------------
    # 1) Convolutional features
    # ---------------------------------------------------
    # p1 = conv_block(inp1, conv_out, kernel_sz, stride, drop_pool)
    # p2 = conv_block(inp2, conv_out, kernel_sz, stride, drop_pool)
    x_dim = 384
    p1 = layers.Dense(x_dim, use_bias=False,name='stem1')(inp1)
    p1 = layers.BatchNormalization(momentum=0.95,name='bn1')(p1)
    
    p2 = layers.Dense(x_dim, use_bias=False,name='stem2')(inp2)
    p2 = layers.BatchNormalization(momentum=0.95,name='bn2')(p2)
    
    # ---------------------------------------------------
    # 2) Self Attention using LLaMA blocks
    # ---------------------------------------------------
    s1 = llama_self_attention(p1, mask1, conv_out, heads)
    s2 = llama_self_attention(p2, mask2, conv_out, heads)

    # ---------------------------------------------------
    # 3) Cross Attention (1→2 and 2→1)
    # ---------------------------------------------------
    c1 = cross_attention_block(p1, p2, p2, mask2, heads, d_dim)
    c2 = cross_attention_block(p2, p1, p1, mask1, heads, d_dim)

    # Add residual (same as PyTorch + skip)
    sc1 = layers.Add()([s1, c1])
    sc2 = layers.Add()([s2, c2])

    sc1 = layers.Dropout(drop_pool)(sc1)
    sc2 = layers.Dropout(drop_pool)(sc2)

    # ---------------------------------------------------
    # 4) Hybrid Pooling (max + mean)
    # ---------------------------------------------------
    h1 = HybridPooling()(sc1)
    h2 = HybridPooling()(sc2)

    merged = layers.Concatenate()([h1, h2])

    # ---------------------------------------------------
    # 5) MLP Head
    # ---------------------------------------------------
    x = layers.Dense(256, activation="relu")(merged)
    x = layers.Dropout(drop_linear)(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(drop_linear)(x)
    out = layers.Dense(1, activation="sigmoid")(x)

    return Model(inputs=[inp1, inp2, mask1, mask2], outputs=out)

In [13]:
import gc
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, recall_score, precision_score, f1_score, roc_auc_score,
    average_precision_score, matthews_corrcoef, confusion_matrix
)

# ================================================================
# Setup
# ================================================================
train_df,unseen_df = df[df['split']=='seen'].copy(),  df[df['split']=='unseen'].copy()

X = train_df.index.values  # DataSequenceLoader loads using df, so features = df itself
y = train_df['label'].values
bs =32
unseen_loader  = DataSequenceLoader(unseen_df, batch_size=bs, shuffle=False)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

fold_results = []


# ================================================================
# 5-Fold Training Loop
# ================================================================
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    WEIGHT_PATH = f"../weights/LlamaCrossAttn_HIV-RLEAAI/model_fold_{fold}.weights.h5"
    os.makedirs(os.path.dirname(WEIGHT_PATH),exist_ok=True)
    print(f"\n=========== FOLD {fold+1} ===========")

    train_df = df.iloc[train_idx]
    test_df  = df.iloc[test_idx]
    bs =32
    # ---- Loaders ----
    train_loader = DataSequenceLoader(train_df, batch_size=bs, shuffle=True)
    test_loader  = DataSequenceLoader(test_df, batch_size=bs, shuffle=False)

    # ---- Build Model ----
    model = build_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    # ---- Train ----
    model.fit(train_loader,  validation_data=test_loader, epochs=100, verbose=1)

    model.save_weights(WEIGHT_PATH)
    # ---- Predict ----
    y_true =unseen_df['label'].values
    y_prob = model.predict(unseen_loader)
    y_pred = (y_prob > 0.5).astype(int)

    # ---- Confusion Matrix ----
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

    # ---- Metrics ----
    ACC  = accuracy_score(y_true, y_pred)
    SN   = recall_score(y_true, y_pred)
    SP   = tn / (tn + fp) if (tn + fp) > 0 else 0
    MCC  = matthews_corrcoef(y_true, y_pred)
    PREC = precision_score(y_true, y_pred)
    NPV  = tn / (tn + fn) if (tn + fn) > 0 else 0
    F1   = f1_score(y_true, y_pred)
    AUC  = roc_auc_score(y_true, y_prob)
    AUPR = average_precision_score(y_true, y_prob)

    fold_metrics = {
        "ACC": ACC,
        "SN": SN,
        "SP": SP,
        "MCC": MCC,
        "Precision": PREC,
        "NPV": NPV,
        "F1": F1,
        "AUC": AUC,
        "AUPRC": AUPR
    }

    fold_results.append(fold_metrics)

    # ---- Print fold results ----
    display(pd.DataFrame([fold_metrics]))

    # ---- Cleanup model + loaders ----
    del model
    del train_loader
    del test_loader
    gc.collect()


# ================================================================
# Final results: per fold + mean
# ================================================================
results_df = pd.DataFrame(fold_results)

print("\n=========== FINAL 5-FOLD RESULTS ===========")
display(results_df)

print("\n=========== MEAN METRICS ===========")
display(results_df.mean())





I0000 00:00:1764835330.435296    8819 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 79078 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:81:00.0, compute capability: 8.0
  self._warn_if_super_not_called()


Epoch 1/100


I0000 00:00:1764835340.199419   13582 service.cc:148] XLA service 0x7ffe74001e90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1764835340.200734   13582 service.cc:156]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
I0000 00:00:1764835341.304179   13582 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m  1/622[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m34:00:42[0m 197s/step - accuracy: 0.5625 - loss: 2.3832

I0000 00:00:1764835531.688784   13582 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 371ms/step - accuracy: 0.5163 - loss: 1.5295



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m529s[0m 534ms/step - accuracy: 0.5163 - loss: 1.5287 - val_accuracy: 0.6965 - val_loss: 0.6416
Epoch 2/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.6226 - loss: 0.6527 - val_accuracy: 0.7126 - val_loss: 0.6199
Epoch 3/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.6621 - loss: 0.6168 - val_accuracy: 0.7301 - val_loss: 0.5799
Epoch 4/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.6959 - loss: 0.5805 - val_accuracy: 0.7388 - val_loss: 0.5636
Epoch 5/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.7117 - loss: 0.5659 - val_accuracy: 0.7410 - val_loss: 0.5459
Epoch 6/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.7089 - loss: 0.5616 - val_accuracy: 0.7498 - val_loss: 0.5429
Epoch 7/100
[1



[1m142/143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 93ms/step



[1m143/143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 512ms/step


Unnamed: 0,ACC,SN,SP,MCC,Precision,NPV,F1,AUC,AUPRC
0,0.8686,0.902365,0.842434,0.739229,0.816113,0.917587,0.857075,0.95273,0.939776



Epoch 1/100


  self._warn_if_super_not_called()


[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 116ms/step - accuracy: 0.5184 - loss: 1.5236



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m103s[0m 143ms/step - accuracy: 0.5184 - loss: 1.5229 - val_accuracy: 0.6184 - val_loss: 0.6547
Epoch 2/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.6091 - loss: 0.6670 - val_accuracy: 0.6949 - val_loss: 0.6091
Epoch 3/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 123ms/step - accuracy: 0.6799 - loss: 0.6009 - val_accuracy: 0.6941 - val_loss: 0.5925
Epoch 4/100
[1m357/622[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m26s[0m 101ms/step - accuracy: 0.6947 - loss: 0.5846

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 128ms/step - accuracy: 0.7406 - loss: 0.5270 - val_accuracy: 0.7428 - val_loss: 0.5322
Epoch 9/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.7399 - loss: 0.5202 - val_accuracy: 0.7408 - val_loss: 0.5302
Epoch 10/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.7435 - loss: 0.5165 - val_accuracy: 0.7448 - val_loss: 0.5264
Epoch 11/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.7562 - loss: 0.5046 - val_accuracy: 0.7613 - val_loss: 0.5143
Epoch 13/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 123ms/step - accuracy: 0.7521 - loss: 0.5032 - val_accuracy: 0.7569 - val_loss: 0.5096
Epoch 14/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 122ms/step - accuracy: 0.7587 - loss: 0.4980 - val_accuracy: 0.7593 - val_loss: 0.5025
Epoch 15/100

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 127ms/step - accuracy: 0.7704 - loss: 0.4776 - val_accuracy: 0.7680 - val_loss: 0.4863
Epoch 18/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.7784 - loss: 0.4760 - val_accuracy: 0.7657 - val_loss: 0.4925
Epoch 19/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.7699 - loss: 0.4761 - val_accuracy: 0.7730 - val_loss: 0.4793
Epoch 20/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.7719 - loss: 0.4769 - val_accuracy: 0.7825 - val_loss: 0.4748
Epoch 21/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.7784 - loss: 0.4619 - val_accuracy: 0.7776 - val_loss: 0.4772
Epoch 22/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 121ms/step - accuracy: 0.7744 - loss: 0.4708 - val_accuracy: 0.7766 - val_loss: 0.4765
Epoch 23/10

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 127ms/step - accuracy: 0.8007 - loss: 0.4247 - val_accuracy: 0.8026 - val_loss: 0.4435
Epoch 35/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.8008 - loss: 0.4257 - val_accuracy: 0.7947 - val_loss: 0.4417
Epoch 36/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 126ms/step - accuracy: 0.8069 - loss: 0.4209 - val_accuracy: 0.8054 - val_loss: 0.4389
Epoch 37/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 123ms/step - accuracy: 0.8080 - loss: 0.4142 - val_accuracy: 0.7957 - val_loss: 0.4395
Epoch 38/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 126ms/step - accuracy: 0.8102 - loss: 0.4133 - val_accuracy: 0.8094 - val_loss: 0.4251
Epoch 39/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.8048 - loss: 0.4178 - val_accuracy: 0.8066 - val_loss: 0.4261
Epoch 40/10

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.7369 - loss: 0.5186 - val_accuracy: 0.7593 - val_loss: 0.5098
Epoch 13/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.7577 - loss: 0.4981 - val_accuracy: 0.7631 - val_loss: 0.4968
Epoch 14/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.7577 - loss: 0.5017 - val_accuracy: 0.7706 - val_loss: 0.4980
Epoch 15/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 122ms/step - accuracy: 0.7581 - loss: 0.4928 - val_accuracy: 0.7692 - val_loss: 0.4869
Epoch 16/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.7642 - loss: 0.4923 - val_accuracy: 0.7712 - val_loss: 0.4902
Epoch 17/100
[1m223/622[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m40s[0m 100ms/step - accuracy: 0.7726 - loss: 0.4854

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 128ms/step - accuracy: 0.8107 - loss: 0.4077 - val_accuracy: 0.8116 - val_loss: 0.4093
Epoch 41/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.8127 - loss: 0.4086 - val_accuracy: 0.8151 - val_loss: 0.4137
Epoch 42/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 127ms/step - accuracy: 0.8141 - loss: 0.4040 - val_accuracy: 0.8157 - val_loss: 0.4128
Epoch 43/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 127ms/step - accuracy: 0.8096 - loss: 0.4119 - val_accuracy: 0.8177 - val_loss: 0.4072
Epoch 44/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.8130 - loss: 0.4086 - val_accuracy: 0.8191 - val_loss: 0.4016
Epoch 45/100
[1m307/622[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m32s[0m 102ms/step - accuracy: 0.8144 - loss: 0.4051

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 126ms/step - accuracy: 0.8414 - loss: 0.3551 - val_accuracy: 0.8344 - val_loss: 0.3679
Epoch 68/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 136ms/step - accuracy: 0.8422 - loss: 0.3566 - val_accuracy: 0.8374 - val_loss: 0.3636
Epoch 69/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 127ms/step - accuracy: 0.8450 - loss: 0.3488 - val_accuracy: 0.8372 - val_loss: 0.3672
Epoch 70/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 134ms/step - accuracy: 0.8392 - loss: 0.3550 - val_accuracy: 0.8368 - val_loss: 0.3638
Epoch 71/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 126ms/step - accuracy: 0.8419 - loss: 0.3573 - val_accuracy: 0.8334 - val_loss: 0.3662
Epoch 72/100
[1m255/622[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m41s[0m 112ms/step - accuracy: 0.8464 - loss: 0.3407

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 128ms/step - accuracy: 0.8598 - loss: 0.3158 - val_accuracy: 0.8370 - val_loss: 0.3542
Epoch 96/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 123ms/step - accuracy: 0.8655 - loss: 0.3084 - val_accuracy: 0.8422 - val_loss: 0.3440
Epoch 97/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.8611 - loss: 0.3132 - val_accuracy: 0.8448 - val_loss: 0.3471
Epoch 98/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 123ms/step - accuracy: 0.8625 - loss: 0.3133 - val_accuracy: 0.8418 - val_loss: 0.3427
Epoch 99/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 126ms/step - accuracy: 0.8566 - loss: 0.3125 - val_accuracy: 0.8418 - val_loss: 0.3455
Epoch 100/100
[1m281/622[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m34s[0m 102ms/step - accuracy: 0.8645 - loss: 0.3009

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 122ms/step - accuracy: 0.7748 - loss: 0.4705 - val_accuracy: 0.7880 - val_loss: 0.4749
Epoch 23/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.7746 - loss: 0.4676 - val_accuracy: 0.7888 - val_loss: 0.4669
Epoch 24/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 122ms/step - accuracy: 0.7877 - loss: 0.4538 - val_accuracy: 0.7977 - val_loss: 0.4714
Epoch 25/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 133ms/step - accuracy: 0.7817 - loss: 0.4540 - val_accuracy: 0.7895 - val_loss: 0.4652
Epoch 26/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 123ms/step - accuracy: 0.7872 - loss: 0.4478 - val_accuracy: 0.7995 - val_loss: 0.4664
Epoch 27/100
[1m359/622[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m29s[0m 110ms/step - accuracy: 0.7809 - loss: 0.4567

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.8185 - loss: 0.3947 - val_accuracy: 0.8249 - val_loss: 0.3998
Epoch 50/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 136ms/step - accuracy: 0.8214 - loss: 0.3936 - val_accuracy: 0.8295 - val_loss: 0.3976
Epoch 51/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.8257 - loss: 0.3839 - val_accuracy: 0.8313 - val_loss: 0.3942
Epoch 52/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 134ms/step - accuracy: 0.8207 - loss: 0.3839 - val_accuracy: 0.8253 - val_loss: 0.4054
Epoch 53/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 134ms/step - accuracy: 0.8243 - loss: 0.3873 - val_accuracy: 0.8227 - val_loss: 0.3948
Epoch 54/100
[1m459/622[0m [32m━━━━━━━━━━━━━━[0m[37m━━━━━━[0m [1m16s[0m 99ms/step - accuracy: 0.8170 - loss: 0.3943

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 153ms/step - accuracy: 0.8355 - loss: 0.3583 - val_accuracy: 0.8426 - val_loss: 0.3689
Epoch 69/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 127ms/step - accuracy: 0.8344 - loss: 0.3724 - val_accuracy: 0.8432 - val_loss: 0.3653
Epoch 70/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 124ms/step - accuracy: 0.8418 - loss: 0.3522 - val_accuracy: 0.8394 - val_loss: 0.3720
Epoch 71/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 125ms/step - accuracy: 0.8416 - loss: 0.3542 - val_accuracy: 0.8380 - val_loss: 0.3719
Epoch 72/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.8387 - loss: 0.3521 - val_accuracy: 0.8462 - val_loss: 0.3682
Epoch 73/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.8393 - loss: 0.3522 - val_accuracy: 0.8480 - val_loss: 0.3601
Epoch 74/10



[1m142/143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 87ms/step



[1m143/143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 97ms/step


Unnamed: 0,ACC,SN,SP,MCC,Precision,NPV,F1,AUC,AUPRC
0,0.872555,0.874182,0.871295,0.742585,0.840348,0.899356,0.856931,0.950831,0.938654



Epoch 1/100


  self._warn_if_super_not_called()


[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 111ms/step - accuracy: 0.5182 - loss: 1.5910



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m98s[0m 138ms/step - accuracy: 0.5182 - loss: 1.5902 - val_accuracy: 0.6995 - val_loss: 0.6446
Epoch 2/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 123ms/step - accuracy: 0.6077 - loss: 0.6633 - val_accuracy: 0.7156 - val_loss: 0.6040
Epoch 3/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 121ms/step - accuracy: 0.6712 - loss: 0.6077 - val_accuracy: 0.7381 - val_loss: 0.5762
Epoch 4/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 122ms/step - accuracy: 0.7008 - loss: 0.5772 - val_accuracy: 0.7470 - val_loss: 0.5548
Epoch 5/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 121ms/step - accuracy: 0.7178 - loss: 0.5636 - val_accuracy: 0.7510 - val_loss: 0.5498
Epoch 6/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 121ms/step - accuracy: 0.7205 - loss: 0.5472 - val_accuracy: 0.7564 - val_loss: 0.5424
Epoch 7/100
[1m

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 127ms/step - accuracy: 0.8257 - loss: 0.3822 - val_accuracy: 0.8279 - val_loss: 0.4054
Epoch 51/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 123ms/step - accuracy: 0.8236 - loss: 0.3900 - val_accuracy: 0.8273 - val_loss: 0.3958
Epoch 52/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 121ms/step - accuracy: 0.8277 - loss: 0.3873 - val_accuracy: 0.8237 - val_loss: 0.3983
Epoch 53/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 122ms/step - accuracy: 0.8269 - loss: 0.3841 - val_accuracy: 0.8283 - val_loss: 0.3979
Epoch 54/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 122ms/step - accuracy: 0.8316 - loss: 0.3764 - val_accuracy: 0.8287 - val_loss: 0.3957
Epoch 55/100
[1m622/622[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 123ms/step - accuracy: 0.8214 - loss: 0.3917 - val_accuracy: 0.8331 - val_loss: 0.3879
Epoch 56/10

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [7]:
import gc
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, recall_score, precision_score, f1_score, roc_auc_score,
    average_precision_score, matthews_corrcoef, confusion_matrix
)

# ================================================================
# Setup
# ================================================================
train_df,unseen_df = df[df['split']=='seen'].copy(),  df[df['split']=='unseen'].copy()

X = train_df.index.values  # DataSequenceLoader loads using df, so features = df itself
y = train_df['label'].values
bs =32
unseen_loader  = DataSequenceLoader(unseen_df, batch_size=bs, shuffle=False)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

fold_results = []


# ================================================================
# 5-Fold Training Loop
# ================================================================
for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    WEIGHT_PATH = f"../weights/LlamaCrossAttn_HIV-RLEAAI/model_fold_{fold}.weights.h5"
    os.makedirs(os.path.dirname(WEIGHT_PATH),exist_ok=True)
    print(f"\n=========== FOLD {fold+1} ===========")

    train_df = df.iloc[train_idx]
    test_df  = df.iloc[test_idx]
    bs =32
    # ---- Loaders ----
    train_loader = DataSequenceLoader(train_df, batch_size=bs, shuffle=True)
    test_loader  = DataSequenceLoader(test_df, batch_size=bs, shuffle=False)

    # ---- Build Model ----
    model = build_model()
    model.load_weights(WEIGHT_PATH)
    # ---- Predict ----
    y_true =unseen_df['label'].values
    y_prob = model.predict(unseen_loader)
    y_pred = (y_prob > 0.5).astype(int)

    # ---- Confusion Matrix ----
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

    # ---- Metrics ----
    ACC  = accuracy_score(y_true, y_pred)
    SN   = recall_score(y_true, y_pred)
    SP   = tn / (tn + fp) if (tn + fp) > 0 else 0
    MCC  = matthews_corrcoef(y_true, y_pred)
    PREC = precision_score(y_true, y_pred)
    NPV  = tn / (tn + fn) if (tn + fn) > 0 else 0
    F1   = f1_score(y_true, y_pred)
    AUC  = roc_auc_score(y_true, y_prob)
    AUPR = average_precision_score(y_true, y_prob)

    fold_metrics = {
        "ACC": ACC,
        "SN": SN,
        "SP": SP,
        "MCC": MCC,
        "Precision": PREC,
        "NPV": NPV,
        "F1": F1,
        "AUC": AUC,
        "AUPRC": AUPR
    }

    fold_results.append(fold_metrics)

    # ---- Print fold results ----
    display(pd.DataFrame([fold_metrics]))

    # ---- Cleanup model + loaders ----
    del model
    del train_loader
    del test_loader
    gc.collect()


# ================================================================
# Final results: per fold + mean
# ================================================================
results_df = pd.DataFrame(fold_results)

print("\n=========== FINAL 5-FOLD RESULTS ===========")
display(results_df)

print("\n=========== MEAN METRICS ===========")
display(results_df.mean())





I0000 00:00:1764881659.861414  118629 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 79078 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0000:81:00.0, compute capability: 8.0
  self._warn_if_super_not_called()
I0000 00:00:1764881663.937401  118956 service.cc:148] XLA service 0x7ffe740029d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1764881663.937476  118956 service.cc:156]   StreamExecutor device (0): NVIDIA A100 80GB PCIe, Compute Capability 8.0
I0000 00:00:1764881664.370173  118956 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m  3/143[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m8s[0m 61ms/step    

I0000 00:00:1764881712.840687  118956 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m142/143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 91ms/step



[1m143/143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 533ms/step


Unnamed: 0,ACC,SN,SP,MCC,Precision,NPV,F1,AUC,AUPRC
0,0.8686,0.902365,0.842434,0.739229,0.816113,0.917587,0.857075,0.95273,0.939776







[1m142/143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 92ms/step



[1m143/143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 104ms/step


Unnamed: 0,ACC,SN,SP,MCC,Precision,NPV,F1,AUC,AUPRC
0,0.878488,0.821842,0.922387,0.752657,0.891376,0.869805,0.855198,0.953557,0.940023







[1m142/143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 89ms/step



[1m143/143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 98ms/step


Unnamed: 0,ACC,SN,SP,MCC,Precision,NPV,F1,AUC,AUPRC
0,0.878488,0.858078,0.894306,0.752861,0.862854,0.890485,0.860459,0.955961,0.9429







[1m142/143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 89ms/step



[1m143/143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 99ms/step


Unnamed: 0,ACC,SN,SP,MCC,Precision,NPV,F1,AUC,AUPRC
0,0.872555,0.874182,0.871295,0.742585,0.840348,0.899356,0.856931,0.950831,0.938654







[1m142/143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 89ms/step



[1m143/143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 99ms/step


Unnamed: 0,ACC,SN,SP,MCC,Precision,NPV,F1,AUC,AUPRC
0,0.874533,0.860594,0.885335,0.745233,0.853293,0.891245,0.856928,0.951662,0.937467





Unnamed: 0,ACC,SN,SP,MCC,Precision,NPV,F1,AUC,AUPRC
0,0.8686,0.902365,0.842434,0.739229,0.816113,0.917587,0.857075,0.95273,0.939776
1,0.878488,0.821842,0.922387,0.752657,0.891376,0.869805,0.855198,0.953557,0.940023
2,0.878488,0.858078,0.894306,0.752861,0.862854,0.890485,0.860459,0.955961,0.9429
3,0.872555,0.874182,0.871295,0.742585,0.840348,0.899356,0.856931,0.950831,0.938654
4,0.874533,0.860594,0.885335,0.745233,0.853293,0.891245,0.856928,0.951662,0.937467





ACC          0.874533
SN           0.863412
SP           0.883151
MCC          0.746513
Precision    0.852797
NPV          0.893696
F1           0.857318
AUC          0.952948
AUPRC        0.939764
dtype: float64