In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
import torch.nn.functional as F
import numpy as np
from itertools import product
from pathlib import Path
import random
import matplotlib.pyplot as plt
from sklearn.metrics import balanced_accuracy_score, accuracy_score, classification_report, precision_score, recall_score, f1_score, roc_auc_score, roc_curve, accuracy_score, roc_auc_score, f1_score, roc_curve, auc, precision_recall_curve, confusion_matrix
from sklearn.model_selection import StratifiedKFold
import seaborn as sns
from Bio import SeqIO
from torch.nn.utils.rnn import pad_sequence
import csv
import os
import subprocess
import random
from Bio import AlignIO, Phylo, SeqIO
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
from sklearn.model_selection import train_test_split
from Bio.Align import MultipleSeqAlignment
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from pathlib import Path
from torch.utils.data import Dataset, DataLoader, random_split
from Bio import SeqIO
import csv

from sklearn.metrics import (
    balanced_accuracy_score,
    roc_auc_score,
    classification_report,
    confusion_matrix
)

In [43]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seed()

In [44]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [45]:
def load_antigen_sequences(fasta_path):
    sequences = {}
    with open(fasta_path, "r") as f:
        for record in SeqIO.parse(f, "fasta"):
            sequences[record.id] = str(record.seq)
    return sequences

fasta_path = Path.cwd() / "antigens.fasta"
ANTIGEN_SEQUENCES = load_antigen_sequences(fasta_path)

print(ANTIGEN_SEQUENCES)

{'SARS-CoV1': 'MFIFLLFLTLTSGSDLDRCTTFDDVQAPNYTQHTSSMRGVYYPDEIFRSDTLYLTQDLFLPFYSNVTGFHTINHTFGNPVIPFKDGIYFAATEKSNVVRGWVFGSTMNNKSQSVIIINNSTNVVIRACNFELCDNPFFAVSKPMGTQTHTMIFDNAFNCTFEYISDAFSLDVSEKSGNFKHLREFVFKNKDGFLYVYKGYQPIDVVRDLPSGFNTLKPIFKLPLGINITNFRAILTAFSPAQDIWGTSAAAYFVGYLKPTTFMLKYDENGTITDAVDCSQNPLAELKCSVKSFEIDKGIYQTSNFRVVPSGDVVRFPNITNLCPFGEVFNATKFPSVYAWERKKISNCVADYSVLYNSTFFSTFKCYGVSATKLNDLCFSNVYADSFVVKGDDVRQIAPGQTGVIADYNYKLPDDFMGCVLAWNTRNIDATSTGNYNYKYRYLRHGKLRPFERDISNVPFSPDGKPCTPPALNCYWPLNDYGFYTTTGIGYQPYRVVVLSFELLNAPATVCGPKLSTDLIKNQCVNFNFNGLTGTGVLTPSSKRFQPFQQFGRDVSDFTDSVRDPKTSEILDISPCSFGGVSVITPGTNASSEVAVLYQDVNCTDVSTAIHADQLTPAWRIYSTGNNVFQTQAGCLIGAEHVDTSYECDIPIGAGICASYHTVSLLRSTSQKSIVAYTMSLGADSSIAYSNNTIAIPTNFSISITTEVMPVSMAKTSVDCNMYICGDSTECANLLLQYGSFCTQLNRALSGIAAEQDRNTREVFAQVKQMYKTPTLKYFGGFNFSQILPDPLKPTKRSFIEDLLFNKVTLADAGFMKQYGECLGDINARDLICAQKFNGLTVLPPLLTDDMIAAYTAALVSGTATAGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKQIANQFNKAISQIQESLTTTSTALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQS

In [46]:
def load_sequence_pairs(positive_file, negative_file):
    seq_pairs = []
    labels = []

    with open(positive_file, 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            if len(row) < 3:
                continue
            antigen_id = row[0].strip()
            heavy = row[1].strip()
            light = row[2].strip()

            # Lookup the antigen sequence
            antigen_seq = ANTIGEN_SEQUENCES.get(antigen_id)
            if antigen_seq is None:
                continue

            antibody_seq = heavy + light

            # Append merged (antigen, antibody) pair
            seq_pairs.append((antigen_seq, antibody_seq))
            labels.append(1.0)

    with open(negative_file, 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            if len(row) < 3:
                continue
            antigen_id = row[0].strip()
            heavy = row[1].strip()
            light = row[2].strip()

            antigen_seq = ANTIGEN_SEQUENCES.get(antigen_id)
            if antigen_seq is None:
                continue

            antibody_seq = heavy + light

            seq_pairs.append((antigen_seq, antibody_seq))
            labels.append(0.0)

    return seq_pairs, labels

In [47]:
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INDEX = {aa: idx for idx, aa in enumerate(AMINO_ACIDS)}

In [48]:
class OneHotSequenceDataset(Dataset):
    def __init__(self, seq_pairs, labels):
        self.seq_pairs = seq_pairs
        self.labels = labels

    def __len__(self):
        return len(self.seq_pairs)

    def __getitem__(self, idx):
        seq1, seq2 = self.seq_pairs[idx]

        if seq1 in globals().get("ANTIGEN_SEQUENCES", {}):
            seq1 = ANTIGEN_SEQUENCES[seq1]

        merged_seq = seq1 + seq2

        enc_seq = one_hot_encode(merged_seq)  # (seq_len, 20)

        return (
            torch.tensor(enc_seq, dtype=torch.float32),  # (seq_len, 20)
            torch.tensor(self.labels[idx], dtype=torch.float32)
        )

In [49]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size=20, hidden_size=128, num_layers=1):
        super(LSTMClassifier, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, (hn, cn) = self.lstm(x)
        out = out[:, -1, :] 
        out = self.fc(out)
        return out.squeeze(1)

In [55]:
def collate_batch(batch):
    sequences, labels = zip(*batch)
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    labels = torch.stack(labels)

    return padded_sequences, labels

In [None]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

AMINO_ACIDS  = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INDEX  = {aa:i for i,aa in enumerate(AMINO_ACIDS)}

def clean_sequence(seq: str) -> str:
    return "".join([aa for aa in seq if aa in AA_TO_INDEX])

def one_hot_encode(seq: str) -> np.ndarray:
    L = len(seq)
    mat = np.zeros((L, len(AMINO_ACIDS)), dtype=np.float32)
    for i, aa in enumerate(seq):
        mat[i, AA_TO_INDEX[aa]] = 1.0
    return mat

class OneHotSequenceDataset(Dataset):
    def __init__(self, seq_pairs, labels, antigen_dict):
        self.seq_pairs    = seq_pairs
        self.labels       = labels
        self.antigen_dict = antigen_dict

    def __len__(self):
        return len(self.seq_pairs)

    def __getitem__(self, idx):
        ag_id, ab_seq = self.seq_pairs[idx]
        ag_seq = self.antigen_dict[ag_id]
        merged = clean_sequence(ag_seq) + clean_sequence(ab_seq)
        oh     = one_hot_encode(merged)
        return torch.tensor(oh), torch.tensor(self.labels[idx], dtype=torch.float32)

def load_antigen_sequences(fasta_path: Path):
    d = {}
    for rec in SeqIO.parse(str(fasta_path), "fasta"):
        d[rec.id] = str(rec.seq)
    return d

def load_sequence_pairs(pos_path: Path, neg_path: Path):
    pairs, labels = [], []
    for path, label in [(pos_path, 1.0), (neg_path, 0.0)]:
        with open(path, newline="") as f:
            reader = csv.reader(f, delimiter="\t")
            for row in reader:
                if len(row) >= 3:
                    ag_id = row[0].strip()
                    heavy = row[1].strip()
                    light = row[2].strip()
                    pairs.append((ag_id, heavy + light))
                    labels.append(label)
    return pairs, labels

class BiLSTMClassifier(nn.Module):
    def __init__(self, input_size=20, hidden_size=128, num_layers=2, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers,
            batch_first=True, bidirectional=True, dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        last   = out[:, -1, :]
        last   = self.dropout(last)
        return self.fc(last).squeeze(1)

In [None]:
if __name__ == "__main__":
    WORKDIR      = Path.cwd()
    antigen_fasta= WORKDIR / "antigens.fasta"
    TRAIN_POS    = WORKDIR / "AbAgIntPre/CoV-AbDab/train_pos.txt"
    TRAIN_NEG    = WORKDIR / "AbAgIntPre/CoV-AbDab/train_neg.txt"
    TEST_POS     = WORKDIR / "AbAgIntPre/CoV-AbDab/test_pos.txt"
    TEST_NEG     = WORKDIR / "AbAgIntPre/CoV-AbDab/test_neg.txt"

    antigen_dict = load_antigen_sequences(antigen_fasta)

    seq_pairs, labels = load_sequence_pairs(TRAIN_POS, TRAIN_NEG)
    print(f"Total train+val samples: {len(seq_pairs)}")
    full_ds = OneHotSequenceDataset(seq_pairs, labels, antigen_dict)

    n_train = int(0.9 * len(full_ds))
    n_val   = len(full_ds) - n_train
    train_ds, val_ds = random_split(full_ds, [n_train, n_val])
    print(f" Train: {len(train_ds)}, Val: {len(val_ds)}")

    # test
    test_pairs, test_labels = load_sequence_pairs(TEST_POS, TEST_NEG)
    test_ds = OneHotSequenceDataset(test_pairs, test_labels, antigen_dict)
    print(f" Test: {len(test_ds)}")

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,  collate_fn=collate_batch)
    val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, collate_fn=collate_batch)
    test_loader  = DataLoader(test_ds,  batch_size=32, shuffle=False, collate_fn=collate_batch)


    model     = BiLSTMClassifier().to(device)
    optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)
    criterion = nn.BCEWithLogitsLoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    num_epochs    = 100
    best_val_loss = float('inf')
    patience      = 10
    wait          = 0

    for epoch in range(1, num_epochs+1):
        model.train()
        train_loss = 0.0
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            logits = model(x_batch)
            loss   = criterion(logits, y_batch)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # val
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                val_loss += criterion(model(x_batch), y_batch).item()
        val_loss /= len(val_loader)

        print(f"Epoch {epoch}/{num_epochs} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_bilstm_model.pth")
            print(" ↳ Saved best model")
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f"No improvement for {patience} epochs, stopping.")
                break

    model.load_state_dict(torch.load("best_bilstm_model.pth"))
    model.eval()

    all_logits = []
    all_labels = []
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            all_logits.append(model(x_batch).cpu())
            all_labels.append(y_batch.cpu())
    logits = torch.cat(all_logits)
    y_true = torch.cat(all_labels).numpy()
    y_prob = torch.sigmoid(logits).numpy()
    y_pred = (y_prob >= 0.5).astype(int)

    bal_acc = balanced_accuracy_score(y_true, y_pred)
    auc     = roc_auc_score(y_true, y_prob)
    report  = classification_report(y_true, y_pred, digits=3)
    cm      = confusion_matrix(y_true, y_pred)

    print("\n=== TEST SET PERFORMANCE ===")
    print(f"Balanced Accuracy: {bal_acc:.3f}")
    print(f"AUC:               {auc:.3f}")
    print("Classification Report:\n", report)
    print("Confusion Matrix:\n", cm)

In [62]:
def plot_roc(fpr, tpr, roc_auc, name, out_path):
    fig, ax = plt.subplots(figsize=(6,6))
    
    ax.plot(fpr, tpr, label=f"AUC={roc_auc:.3f}", linewidth=2)
    ax.plot([0,1], [0,1], '--', color='orange', linewidth=2)
    
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.set_title(f"{name} ROC")
    
    ax.legend(loc='upper left', fontsize='small', frameon=True)

    ax.set_aspect('equal', 'box')
    
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def evaluate_torch(name, model, data_loader, device, output_dir):
    model.eval()
    all_probs = []
    all_preds = []
    all_labels = []

    t0 = time.time()
    with torch.no_grad():
        for x_batch, y_batch in data_loader:
            x_batch = x_batch.to(device)
            logits  = model(x_batch)
            probs   = torch.sigmoid(logits).cpu().numpy().ravel()
            preds   = (probs >= 0.5).astype(int)

            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(y_batch.cpu().numpy().astype(int))
    elapsed = time.time() - t0

    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    roc_auc = roc_auc_score(all_labels, all_probs)
    report  = classification_report(all_labels, all_preds, digits=3)
    cm      = confusion_matrix(all_labels, all_preds)

    print(f"\n--- {name} Evaluation ---")
    print(f"Predict time: {elapsed:.3f}s")
    print(f"Balanced Acc: {bal_acc:.3f}    ROC-AUC: {roc_auc:.3f}")
    print(report)

    output_dir.mkdir(parents=True, exist_ok=True)

    plt.figure(figsize=(4,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.title(f"{name} Confusion")
    plt.tight_layout()
    plt.savefig(output_dir / f"{name}_confusion.png", dpi=200)
    plt.close()

    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    plt.figure(); plt.plot(fpr, tpr, label=f'AUC={roc_auc:.3f}'); plt.plot([0,1],[0,1],'--');
    plt.title(f'{name} ROC'); plt.xlabel('FPR'); plt.ylabel('TPR'); plt.legend();
    plt.savefig(OUTPUT_DIR/f"bilstm_roc.png", dpi=200); plt.close()

###### 

In [None]:
OUTPUT_DIR = WORKDIR / "outputs"
evaluate_torch(
    name="BiLSTM",
    model=model,
    data_loader=val_loader,
    device=device,
    output_dir=OUTPUT_DIR
)