In [1]:
import os
import tqdm
import pandas as pd
import numpy as np

from modules.data import PocketDataset, Dataloader
from modules.Predictor import Pseq2SitesPredictor
from modules.helpers import convert_bs, get_results

In [2]:
config = {
    "paths": {
        "input_dir_embedding": os.path.join("data_embeddings", "pdbbind"),
        "data_path": os.path.join("data_preprocessed", "PDBbind_data.tsv"),
        "save_path": os.path.join("outputs", "pdbbind", "CV"),
        "model_path": os.path.join("outputs", "pdbbind", "model"),
        "result_path": os.path.join("outputs", "pdbbind", "prediction.tsv")
    },
    "train": {
        "epochs": 50,
        "batch_size": 8,
        "dropout": 0.3
    },
    "architectures": {
        "hidden_size": 256,
        "prots_input_dim": 1024,
        "num_layer": 3,
        "hidden_act": "gelu",
        "hidden_size": 256,
        "intermediate_size": 512,
        "num_attention_heads": 8
    },
    "prots": {
        "max_lengths": 1500,
        "max_chains": 13
    }
}

# 1. Load data

In [3]:
df_seq = pd.read_csv(config["paths"]["data_path"], sep="\t")
df_seq

Unnamed: 0,PDB,Sequence,BS
0,4x14,VEVLEVKTGVDSITEVECFLTPEMGDPDEHLRGFSKSISISDTFES...,"361,362,364,367,368,369,370,371,372,373,374,50..."
1,4ruu,TRDQNGTWEMESNENFEGYMKALDIDFATRKIAVRLTQTLVIDQDG...,"1,2,3,4,5,7,9,12,14,15,16,17,18,19,20,21,22,23..."
2,5hx8,IVSEKKPATEVDPTHFEKRFLKRIRDLGEGHFGKVELCRYDPEGDN...,"23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,4..."
3,2ymt,SAPIPDLKVFEREGVQLNLSFIRPPENPALLLITITATNFSEGDVT...,"16,18,32,34,47,48,49,50,51,52,53,54,55,56,57,5..."
4,4km2,MVGLIWAQATSGVIGRGGDIPWRLPEDQAHFREITMGHTIVMGRRT...,"2,3,4,5,6,7,8,13,14,15,17,18,19,20,21,22,23,24..."
...,...,...,...
9973,5ew9,QWALEDFEIGRPLGKGKFGNVYLAREKQSKFILALKVLFKAQLEKA...,"10,11,12,13,14,15,16,17,18,19,20,21,22,23,30,3..."
9974,4f7l,DKMDYDFKVKLSSERERVEDLFEYEGCKVGRGTYGHVYKAKRKDGK...,"26,27,28,29,30,31,32,33,34,35,36,37,38,39,49,5..."
9975,4elh,MIVSFMVAMDENRVIGKDNNLPWRLPSELQYVKKTTMGHPLIMGRK...,"3,4,5,6,7,8,9,12,13,14,15,16,17,18,19,20,21,22..."
9976,4o3a,GANKTVVVTTILESPYVMMKKNHEMLEGNERYEGYCVDLAAEIAKH...,"9,10,11,12,13,14,15,35,55,58,59,60,61,62,63,70..."


In [4]:
protein_feats = []
for pdbid in tqdm.tqdm(df_seq["PDB"].values):
    filepath = os.path.join(config["paths"]["input_dir_embedding"], f"{pdbid}.npy")
    protein_feats.append(np.load(filepath))

100%|███████████████████████████████████████████████████████████████████████| 9978/9978 [00:05<00:00, 1706.31it/s]


# 2. 5-Fold CV

In [5]:
class KFold:
    def __init__(self, n_splits, random_state):
        self.n_splits = n_splits
        self.random_state = random_state
        self.rng = np.random.default_rng(seed=self.random_state)
        
    def split(self, X):
        total_records = len(X)
        fold_size = total_records // self.n_splits
        remainder = total_records % self.n_splits
        
        ## index permutated
        permuted_idx = self.rng.permutation(total_records)
        
        ## fold
        folds = []
        i = 0
        while i < total_records:
            if len(folds) < remainder:
                i_next = i + fold_size + 1
            else:
                i_next = i + fold_size
            
            folds.append(permuted_idx[i:i_next])
            i = i_next

        ## generator
        for i in range(self.n_splits):
            test_idx = folds[i]
            train_idx = []
            for j in range(self.n_splits):
                if j != i:
                    train_idx.append(folds[j])
            train_idx = np.concatenate(train_idx)
            yield train_idx, test_idx

In [11]:
kf = KFold(5, 2025)

scores = []

for idx, (train_idx, val_idx) in enumerate(kf.split(df_seq)):
    ## validation data
    valid_IDs = df_seq.loc[val_idx, "PDB"].values.tolist()
    valid_BS = df_seq.loc[val_idx, "BS"].values.tolist()
    valid_seqs = df_seq.loc[val_idx, "Sequence"].values.tolist()
    valid_feats = [protein_feats[i] for i in val_idx]
    
    ## data loader
    valid_dataset = PocketDataset(valid_IDs, valid_feats, valid_seqs, valid_BS)
    valid_loader = Dataloader(valid_dataset, batch_size=config["train"]["batch_size"], shuffle=False)
    
    ## output_directory
    output_dir = os.path.join(config["paths"]["save_path"], f"fold{idx}")
    
    ## run
    print(f"Fold {idx} is running ...")
    trainer = Pseq2SitesPredictor(config)
    trainer.load_checkpoint(os.path.join(output_dir, "Pseq2Sites.pth"))
    preds = trainer.predict(valid_loader)
    
    ## confusion matrix
    results = get_results(
        valid_BS,
        convert_bs(preds),
        valid_seqs
    )

    ## save
    scores.append({
        "Fold": idx,
        "Precision": results[0],
        "Recall": results[1],
        "Specificity": results[2],
        "G-mean": results[3],
        "Accuracy": results[4],
        "F1-score": results[5],
        "F2-score": results[6],
    })

Fold 0 is running ...
Fold 1 is running ...
Fold 2 is running ...
Fold 3 is running ...
Fold 4 is running ...


In [12]:
df_scores = pd.DataFrame(scores)

In [13]:
df_scores

Unnamed: 0,Fold,Precision,Recall,Specificity,G-mean,Accuracy,F1-score,F2-score
0,0,0.63,0.92,0.89,0.9,0.91,0.75,0.85
1,1,0.69,0.89,0.92,0.91,0.91,0.78,0.84
2,2,0.63,0.9,0.89,0.9,0.9,0.74,0.83
3,3,0.65,0.9,0.9,0.9,0.9,0.76,0.84
4,4,0.65,0.91,0.9,0.9,0.91,0.76,0.84
