In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
import torch.nn.functional as F
from itertools import product
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
import csv
from pathlib import Path
import os
import subprocess
import random
from Bio import AlignIO, Phylo, SeqIO
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
import seaborn as sns
from sklearn.model_selection import train_test_split
from Bio.Align import MultipleSeqAlignment
from pathlib import Path
import time
import numpy as np
import joblib
from sklearn.pipeline          import Pipeline
from sklearn.preprocessing     import StandardScaler
from sklearn.linear_model      import LogisticRegression
from sklearn.ensemble          import RandomForestClassifier, StackingClassifier
from xgboost                   import XGBClassifier
from sklearn.model_selection   import StratifiedKFold, GridSearchCV
from sklearn.decomposition     import PCA
from sklearn.metrics           import (
    balanced_accuracy_score,
    roc_auc_score,
    classification_report,
    confusion_matrix,
    roc_curve,
    auc,
    f1_score
)

In [5]:
base = Path('/content/drive/MyDrive')
print("Top-level of MyDrive:", sorted(p.name for p in base.iterdir()))

DATA_DIR = base
print("Using DATA_DIR =", DATA_DIR)

Top-level of MyDrive: ['Colab Notebooks', 'Masters thesis (1).zip', 'PredictionModel (2).ipynb', 'PyRosetta', 'StackingEnsemble (1).ipynb', 'antigens.fasta', 'binding_results', 'data', 'test_neg.txt', 'test_neg_1.txt', 'test_neg_cdrh3.txt', 'test_pos.txt', 'test_pos_1.txt', 'test_pos_cdrh3.txt', 'train_neg (1).gdoc', 'train_neg.gdoc', 'train_neg.txt', 'train_neg_1.txt', 'train_neg_cdrh3.txt', 'train_pos.txt', 'train_pos_1.txt', 'train_pos_cdrh3.txt']
Using DATA_DIR = /content/drive/MyDrive


In [6]:
def load_antigen_sequences(fasta_path):
    sequences = {}
    with open(fasta_path, "r") as f:
        for record in SeqIO.parse(f, "fasta"):
            sequences[record.id] = str(record.seq)
    return sequences

fasta_path = DATA_DIR/ "antigens.fasta"
ANTIGEN_SEQUENCES = load_antigen_sequences(fasta_path)

print(ANTIGEN_SEQUENCES)

{'SARS-CoV1': 'MFIFLLFLTLTSGSDLDRCTTFDDVQAPNYTQHTSSMRGVYYPDEIFRSDTLYLTQDLFLPFYSNVTGFHTINHTFGNPVIPFKDGIYFAATEKSNVVRGWVFGSTMNNKSQSVIIINNSTNVVIRACNFELCDNPFFAVSKPMGTQTHTMIFDNAFNCTFEYISDAFSLDVSEKSGNFKHLREFVFKNKDGFLYVYKGYQPIDVVRDLPSGFNTLKPIFKLPLGINITNFRAILTAFSPAQDIWGTSAAAYFVGYLKPTTFMLKYDENGTITDAVDCSQNPLAELKCSVKSFEIDKGIYQTSNFRVVPSGDVVRFPNITNLCPFGEVFNATKFPSVYAWERKKISNCVADYSVLYNSTFFSTFKCYGVSATKLNDLCFSNVYADSFVVKGDDVRQIAPGQTGVIADYNYKLPDDFMGCVLAWNTRNIDATSTGNYNYKYRYLRHGKLRPFERDISNVPFSPDGKPCTPPALNCYWPLNDYGFYTTTGIGYQPYRVVVLSFELLNAPATVCGPKLSTDLIKNQCVNFNFNGLTGTGVLTPSSKRFQPFQQFGRDVSDFTDSVRDPKTSEILDISPCSFGGVSVITPGTNASSEVAVLYQDVNCTDVSTAIHADQLTPAWRIYSTGNNVFQTQAGCLIGAEHVDTSYECDIPIGAGICASYHTVSLLRSTSQKSIVAYTMSLGADSSIAYSNNTIAIPTNFSISITTEVMPVSMAKTSVDCNMYICGDSTECANLLLQYGSFCTQLNRALSGIAAEQDRNTREVFAQVKQMYKTPTLKYFGGFNFSQILPDPLKPTKRSFIEDLLFNKVTLADAGFMKQYGECLGDINARDLICAQKFNGLTVLPPLLTDDMIAAYTAALVSGTATAGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKQIANQFNKAISQIQESLTTTSTALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQS

In [7]:
AA = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
DP = list(product(AA, AA))
DP_list = []
for i in DP:
    DP_list.append(str(i[0]) + str(i[1]))

AAindex_list = DP_list.copy()

def returnCKSAAPcode(query_seq, k):
    code_final = []

    for turns in range(k + 1):
        DP_dic = {}
        code = []
        code_order = []
        for i in DP_list:
            DP_dic[i] = 0
        for i in range(len(query_seq) - turns - 1):
            tmp_dp_1 = query_seq[i]                # first amino acid
            tmp_dp_2 = query_seq[i + turns + 1]    # second amino acid
            tmp_dp = tmp_dp_1 + tmp_dp_2           # combine them into a dipeptide string

            if tmp_dp in DP_dic.keys():
                DP_dic[tmp_dp] += 1
            else:
                DP_dic[tmp_dp] = 1

        for i, j in DP_dic.items():
            code.append(j / (len(query_seq) - turns - 1))

        for i in AAindex_list:
            code_order.append(code[DP_list.index(i)])

        code_final += code

    return code_final

def get_cksaap_length(sample_seq, k):
    code = returnCKSAAPcode(sample_seq, k)
    return len(code)

In [8]:
def load_tsv_pairs(pos_path, neg_path):
    pairs, labels = [], []
    for p, lab in [(pos_path, 1), (neg_path, 0)]:
        with open(p) as f:
            for line in f:
                ag_id, heavy, light = line.strip().split("\t")
                pairs.append((ag_id, heavy + light))
                labels.append(lab)
    return pairs, np.array(labels, dtype=int)

In [None]:
# ──────────────────────────────────────────────
# PARAMETERS & DATA LOADING
# ──────────────────────────────────────────────
print("Using DATA_DIR =", DATA_DIR)

OUTPUT_DIR = DATA_DIR / "binding_results"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
K       = 10
PCA_VAR = 0.95
CV_FOLDS = 5

train_pairs, y_train = load_tsv_pairs(DATA_DIR/"train_pos.txt", DATA_DIR/"train_neg.txt")
test_pairs,  y_test  = load_tsv_pairs(DATA_DIR/"test_pos.txt",  DATA_DIR/"test_neg.txt")

train_seqs = [(ANTIGEN_SEQUENCES[ag], ab) for ag, ab in train_pairs]
test_seqs  = [(ANTIGEN_SEQUENCES[ag], ab) for ag, ab in test_pairs]

def build_features(pairs, k=K, pca=None):
    X = []
    for ag, ab in pairs:
        c1    = returnCKSAAPcode(ag, k)
        c2    = returnCKSAAPcode(ab, k)
        diff  = np.abs(np.array(c1) - np.array(c2))
        prod  = np.array(c1) * np.array(c2)
        X.append(np.concatenate([c1, c2, diff, prod]))
    X = np.array(X, dtype=np.float32)
    return pca.transform(X) if pca else X

t0 = time.time()
X_train = build_features(train_seqs)
X_test  = build_features(test_seqs)
print(f"Feature building took {time.time() - t0:.2f}s")
joblib.dump((X_train, y_train, X_test, y_test), OUTPUT_DIR/"features.pkl")

t0 = time.time()
pca = PCA(n_components=PCA_VAR, svd_solver='full')
pca.fit(X_train)
X_train_pca = pca.transform(X_train)
X_test_pca  = pca.transform(X_test)
print(f"PCA fit & transform took {time.time() - t0:.2f}s")
joblib.dump(pca, OUTPUT_DIR/"pca_model.joblib")

LR_PARAM_GRID = {
    'clf__C': np.logspace(-3, 3, 7)
}

RF_PARAM_GRID = {
    'n_estimators': [100, 200, 300, 500],
    'max_depth': [4, 6, 8, 10],
    'min_samples_leaf': [1, 2, 4, 8],
    'class_weight': ['balanced']
}

XGB_PARAM_GRID = {
    'n_estimators': [100, 300],
    'max_depth': [4, 6],
    'learning_rate': [0.001, 0.05, 0.1],
}

# ──────────────────────────────────────────────
# LOGISTIC REGRESSION
# ──────────────────────────────────────────────
t0 = time.time()
lr_pipe = Pipeline([
    ('scale', StandardScaler()),
    ('clf', LogisticRegression(class_weight='balanced', max_iter=2000))
])
gs_lr = GridSearchCV(
    lr_pipe,
    LR_PARAM_GRID,
    cv=StratifiedKFold(CV_FOLDS, shuffle=True, random_state=42),
    scoring='roc_auc',
    n_jobs=-1
)
gs_lr.fit(X_train_pca, y_train)
lr = gs_lr.best_estimator_
print(f"LR tuning took {time.time() - t0:.2f}s")
joblib.dump(lr, OUTPUT_DIR/"lr_model.joblib")

# ──────────────────────────────────────────────
# RANDOM FOREST
# ──────────────────────────────────────────────
t0 = time.time()
gs_rf = GridSearchCV(
    RandomForestClassifier(),
    RF_PARAM_GRID,
    cv=StratifiedKFold(CV_FOLDS, shuffle=True, random_state=42),
    scoring='roc_auc',
    n_jobs=-1
)
gs_rf.fit(X_train_pca, y_train)
rf = gs_rf.best_estimator_
print(f"RF tuning took {time.time() - t0:.2f}s")
joblib.dump(rf, OUTPUT_DIR/"rf_model.joblib")

# ──────────────────────────────────────────────
# XGBOOST
# ──────────────────────────────────────────────
t0 = time.time()
gs_xgb = GridSearchCV(
    XGBClassifier(eval_metric='auc'),
    XGB_PARAM_GRID,
    cv=StratifiedKFold(CV_FOLDS, shuffle=True, random_state=42),
    scoring='roc_auc',
    n_jobs=-1
)
gs_xgb.fit(X_train_pca, y_train)
best_xgb = gs_xgb.best_estimator_
print(f"XGBoost tuning took {time.time() - t0:.2f}s")
joblib.dump(best_xgb, OUTPUT_DIR/"xgb_best.joblib")

# ──────────────────────────────────────────────
# STACKING
# ──────────────────────────────────────────────
t0 = time.time()
estimators = [('lr', lr), ('rf', rf), ('xgb', best_xgb)]
stack = StackingClassifier(
    estimators=estimators,
    final_estimator=LogisticRegression(),
    cv=StratifiedKFold(CV_FOLDS, shuffle=True, random_state=42),
    n_jobs=-1
)
stack.fit(X_train_pca, y_train)
print(f"Stacking training took {time.time() - t0:.2f}s")
joblib.dump(stack, OUTPUT_DIR/"stack_model.joblib")

# ──────────────────────────────────────────────
# EVALUATION FUNCTION
# ──────────────────────────────────────────────
def evaluate(name, model, X, y, prefix):
    probs = model.predict_proba(X)[:,1]
    preds = (probs >= 0.5).astype(int)
    bal   = balanced_accuracy_score(y, preds)
    roc   = roc_auc_score(y, probs)
    f1    = f1_score(y, preds)
    print(f"{name}: Balanced Acc={bal:.3f}, AUC={roc:.3f}, F1={f1:.3f}")
    print(classification_report(y, preds, digits=3))

    fpr, tpr, _ = roc_curve(y, probs)
    plt.figure(); plt.plot(fpr, tpr, label=f'AUC={roc:.3f}'); plt.plot([0,1],[0,1],'--')
    plt.title(f'{name} ROC'); plt.xlabel('FPR'); plt.ylabel('TPR'); plt.legend()
    plt.savefig(OUTPUT_DIR/f"{prefix}_roc.png", dpi=200); plt.close()

    cm = confusion_matrix(y, preds)
    plt.figure(figsize=(4,4)); sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'{name} Confusion'); plt.tight_layout()
    plt.savefig(OUTPUT_DIR/f"{prefix}_confusion.png", dpi=200); plt.close()

# ──────────────────────────────────────────────
# EVALUATE MODELS
# ──────────────────────────────────────────────
for nm, mdl in [('lr', lr), ('rf', rf), ('xgb', best_xgb), ('stack', stack)]:
    evaluate(nm.upper(), mdl, X_test_pca, y_test, nm)

print("All done — results and models saved to", OUTPUT_DIR)