In [16]:
!pip install Bio

from pathlib import Path
from Bio import SeqIO
from itertools import product
!pip install transformers torch sentencepiece scikit-learn

import numpy as np
import torch
from pathlib import Path
from transformers import BertModel, BertTokenizerFast
from sklearn.pipeline          import Pipeline
from sklearn.preprocessing     import StandardScaler
from sklearn.linear_model      import LogisticRegressionCV
from sklearn.model_selection   import StratifiedKFold, train_test_split
from sklearn.metrics           import (
    balanced_accuracy_score,
    roc_auc_score,
    classification_report,
    confusion_matrix
)



In [17]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [18]:
drive.mount('/content/drive')

base = Path('/content/drive/MyDrive')
print("Top-level of MyDrive:", sorted(p.name for p in base.iterdir()))

DATA_DIR = base
print("Using DATA_DIR =", DATA_DIR)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Top-level of MyDrive: ['Colab Notebooks', 'Masters thesis (1).zip', 'PredictionModel (2).ipynb', 'PyRosetta', 'StackingEnsemble (1).ipynb', 'antigens.fasta', 'binding_results', 'data', 'test_neg.txt', 'test_neg_1.txt', 'test_neg_cdrh3.txt', 'test_pos.txt', 'test_pos_1.txt', 'test_pos_cdrh3.txt', 'train_neg (1).gdoc', 'train_neg.gdoc', 'train_neg.txt', 'train_neg_1.txt', 'train_neg_cdrh3.txt', 'train_pos.txt', 'train_pos_1.txt', 'train_pos_cdrh3.txt']
Using DATA_DIR = /content/drive/MyDrive


In [19]:
def load_antigen_sequences(fasta_path):
    sequences = {}
    with open(fasta_path, "r") as f:
        for record in SeqIO.parse(f, "fasta"):
            sequences[record.id] = str(record.seq)
    return sequences

fasta_path = DATA_DIR/ "antigens.fasta"
ANTIGEN_SEQUENCES = load_antigen_sequences(fasta_path)

print(ANTIGEN_SEQUENCES)

{'SARS-CoV1': 'MFIFLLFLTLTSGSDLDRCTTFDDVQAPNYTQHTSSMRGVYYPDEIFRSDTLYLTQDLFLPFYSNVTGFHTINHTFGNPVIPFKDGIYFAATEKSNVVRGWVFGSTMNNKSQSVIIINNSTNVVIRACNFELCDNPFFAVSKPMGTQTHTMIFDNAFNCTFEYISDAFSLDVSEKSGNFKHLREFVFKNKDGFLYVYKGYQPIDVVRDLPSGFNTLKPIFKLPLGINITNFRAILTAFSPAQDIWGTSAAAYFVGYLKPTTFMLKYDENGTITDAVDCSQNPLAELKCSVKSFEIDKGIYQTSNFRVVPSGDVVRFPNITNLCPFGEVFNATKFPSVYAWERKKISNCVADYSVLYNSTFFSTFKCYGVSATKLNDLCFSNVYADSFVVKGDDVRQIAPGQTGVIADYNYKLPDDFMGCVLAWNTRNIDATSTGNYNYKYRYLRHGKLRPFERDISNVPFSPDGKPCTPPALNCYWPLNDYGFYTTTGIGYQPYRVVVLSFELLNAPATVCGPKLSTDLIKNQCVNFNFNGLTGTGVLTPSSKRFQPFQQFGRDVSDFTDSVRDPKTSEILDISPCSFGGVSVITPGTNASSEVAVLYQDVNCTDVSTAIHADQLTPAWRIYSTGNNVFQTQAGCLIGAEHVDTSYECDIPIGAGICASYHTVSLLRSTSQKSIVAYTMSLGADSSIAYSNNTIAIPTNFSISITTEVMPVSMAKTSVDCNMYICGDSTECANLLLQYGSFCTQLNRALSGIAAEQDRNTREVFAQVKQMYKTPTLKYFGGFNFSQILPDPLKPTKRSFIEDLLFNKVTLADAGFMKQYGECLGDINARDLICAQKFNGLTVLPPLLTDDMIAAYTAALVSGTATAGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKQIANQFNKAISQIQESLTTTSTALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQS

In [20]:
AA = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
DP = list(product(AA, AA))
DP_list = []
for i in DP:
    DP_list.append(str(i[0]) + str(i[1]))

AAindex_list = DP_list.copy()

def returnCKSAAPcode(query_seq, k):
    code_final = []
    for turns in range(k + 1):
        DP_dic = {}
        code = []
        code_order = []

        for i in DP_list:
            DP_dic[i] = 0

        for i in range(len(query_seq) - turns - 1):
            tmp_dp_1 = query_seq[i]                # first amino acid
            tmp_dp_2 = query_seq[i + turns + 1]    # second amino acid
            tmp_dp = tmp_dp_1 + tmp_dp_2           # combine them into a dipeptide string

            if tmp_dp in DP_dic.keys():
                DP_dic[tmp_dp] += 1
            else:
                DP_dic[tmp_dp] = 1

        for i, j in DP_dic.items():
            code.append(j / (len(query_seq) - turns - 1))

        for i in AAindex_list:
            code_order.append(code[DP_list.index(i)])
        code_final += code

    return code_final

def get_cksaap_length(sample_seq, k):
    code = returnCKSAAPcode(sample_seq, k)
    return len(code)

In [None]:
OUTPUT_DIR  = DATA_DIR/"binding_results"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

def load_tsv_pairs(pos_path, neg_path):
    pairs, labels = [], []
    for p, lab in [(pos_path, 1), (neg_path, 0)]:
        with open(p) as f:
            for line in f:
                ag_id, heavy, light = line.strip().split("\t")
                pairs.append((ag_id, heavy + light))
                labels.append(lab)
    return pairs, np.array(labels, dtype=int)

train_pairs, y_train = load_tsv_pairs(DATA_DIR/"train_pos.txt", DATA_DIR/"train_neg.txt")
test_pairs,  y_test  = load_tsv_pairs(DATA_DIR/"test_pos.txt",  DATA_DIR/"test_neg.txt")

print(f"Train: {len(train_pairs)}  Pos/Neg = {np.bincount(y_train)}")
print(f"Test:  {len(test_pairs)}  Pos/Neg = {np.bincount(y_test)}")

train_seqs = [(ANTIGEN_SEQUENCES[ag], ab) for ag,ab in train_pairs]
test_seqs  = [(ANTIGEN_SEQUENCES[ag], ab) for ag,ab in test_pairs]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model     = BertModel.from_pretrained("Rostlab/prot_bert").to(device)
model.eval()

def embed_seq(seq, max_len=512):
    """Return mean‐pooled ProtBert embedding for a single sequence."""
    tokens = " ".join(list(seq))
    inputs = tokenizer(tokens, return_tensors="pt", add_special_tokens=True)
    inputs = {k:v.to(device) for k,v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs).last_hidden_state  # (1, L, 1024)
    emb = out[0, 1:-1].mean(dim=0).cpu().numpy()  # (1024,)
    return emb

def build_features(pairs, k=10):
    X_ck, X_pb, ys = [], [], []
    for ag_seq, ab_seq in pairs:
        c1 = returnCKSAAPcode(ag_seq, k)
        c2 = returnCKSAAPcode(ab_seq,   k)
        diff = [abs(a-b) for a,b in zip(c1,c2)]
        prod = [   a*b   for a,b in zip(c1,c2)]
        ck_feat = c1 + c2 + diff + prod

        e1 = embed_seq(ag_seq)
        e2 = embed_seq(ab_seq)

        X_ck.append(ck_feat)
        X_pb.append(np.concatenate([e1, e2]))
    return np.vstack(X_ck), np.vstack(X_pb)

best_k = 10
print("Building train features…")
X_train_ck, X_train_pb = build_features(train_seqs, best_k)
print("Building test features…")
X_test_ck,  X_test_pb  = build_features(test_seqs,  best_k)

X_train = np.hstack([X_train_ck, X_train_pb])
X_test  = np.hstack([X_test_ck,  X_test_pb ])

print("X_train shape =", X_train.shape)
print("X_test  shape =", X_test.shape)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
pipe = Pipeline([
    ("scale", StandardScaler()),
    ("clf",   LogisticRegressionCV(
        Cs=[0.01,0.1,1,10],
        cv=cv,
        scoring="roc_auc",
        class_weight="balanced",
        max_iter=2000,
        n_jobs=-1
    ))
])

aucs = []
for tr, val in cv.split(X_train, y_train):
    pipe.fit(X_train[tr], y_train[tr])
    probs = pipe.predict_proba(X_train[val])[:,1]
    aucs.append(roc_auc_score(y_train[val], probs))
print(f"CV AUC = {np.mean(aucs):.3f} ± {np.std(aucs):.3f}")

pipe.fit(X_train, y_train)
probs = pipe.predict_proba(X_test)[:,1]
preds = pipe.predict(X_test)

print("\n=== Test performance ===")
print("Balanced Acc:", balanced_accuracy_score(y_test, preds))
print("AUC:         ", roc_auc_score(y_test, probs))
print(classification_report(y_test, preds, digits=3))
print(confusion_matrix(y_test, preds))

def evaluate(name, model, X, y, prefix, output_dir):
    probs = model.predict_proba(X)[:,1]
    preds = (probs >= 0.5).astype(int)
    bal = balanced_accuracy_score(y, preds)
    roc = roc_auc_score(y, probs)
    prec_recall_f1 = classification_report(y, preds, digits=3)
    cm = confusion_matrix(y, preds)

    print(f"{name}: Balanced Acc={bal:.3f}, AUC={roc:.3f}")
    print(prec_recall_f1)

    # ROC
    fpr, tpr, _ = roc_curve(y, probs)
    plt.figure(); plt.plot(fpr, tpr, label=f'AUC={roc:.3f}'); plt.plot([0,1],[0,1],'--')
    plt.title(f'{name} ROC'); plt.xlabel('FPR'); plt.ylabel('TPR'); plt.legend()
    plt.savefig(output_dir/f"{prefix}_roc.png", dpi=200); plt.close()

    # Confusion
    plt.figure(figsize=(4,4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'{name} Confusion'); plt.tight_layout()
    plt.savefig(output_dir/f"{prefix}_confusion.png", dpi=200); plt.close()

In [None]:
evaluate(
  name="ProtBERT+CKSAAP",
  model=pipe,
  X=X_test,
  y=y_test,
  prefix="protbert_cksaap",
  output_dir=OUTPUT_DIR
)