In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
import torch.nn.functional as F
import numpy as np
from itertools import product
import random
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve, accuracy_score, roc_auc_score, f1_score, roc_curve, auc, precision_recall_curve, confusion_matrix
from sklearn.model_selection import StratifiedKFold
import csv
from pathlib import Path
import os
import subprocess
from Bio import AlignIO, Phylo, SeqIO
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
import seaborn as sns
from sklearn.model_selection import train_test_split
from Bio.Align import MultipleSeqAlignment
from collections import Counter
import time
from sklearn.metrics import (
    balanced_accuracy_score, classification_report
)

In [None]:
def set_seed(seed=42):
    """Seed the PyTorch, NumPy and Python ``random`` generators with one value.

    Args:
        seed: Integer seed handed to each generator (default 42).
    """
    # Same order as before: torch first, then NumPy, then the stdlib RNG.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)


# Seed once, with the default value, before anything stochastic runs.
set_seed()

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
def load_antigen_sequences(fasta_path):
    """Read a FASTA file into a dictionary of sequences keyed by record id.

    Args:
        fasta_path: Path of the FASTA file to open.

    Returns:
        dict: ``{record.id: str(record.seq)}`` for every parsed record. When
        an id appears more than once, the last record with that id is kept.
    """
    with open(fasta_path, "r") as handle:
        return {rec.id: str(rec.seq) for rec in SeqIO.parse(handle, "fasta")}

# Antigen sequences are looked up by id from a FASTA file next to the notebook.
fasta_path = Path.cwd() / "antigens.fasta"
ANTIGEN_SEQUENCES = load_antigen_sequences(fasta_path)

# Report a compact summary (id -> sequence length) rather than dumping every
# full sequence into the cell output.
print(f"Loaded {len(ANTIGEN_SEQUENCES)} antigen sequences from {fasta_path}")
print({antigen_id: len(seq) for antigen_id, seq in ANTIGEN_SEQUENCES.items()})

In [None]:
def _read_pairs(path):
    """Parse one tab-separated file into a list of ``(antigen, antibody)`` pairs.

    Rows with fewer than three columns (including blank lines) are skipped.
    The antibody string is the heavy chain (column 2) followed directly by the
    light chain (column 3); any further columns are ignored.
    """
    pairs = []
    # newline='' is what the csv module asks for when it is given a file object.
    with open(path, 'r', newline='') as f:
        for row in csv.reader(f, delimiter='\t'):
            if len(row) < 3:
                continue
            pairs.append((row[0], row[1] + row[2]))  # heavy + light
    return pairs


def load_data(positive_file, negative_file):
    """
    Load positive and negative antigen/antibody pairs.

    Each line of file:
      antigen \t heavy_chain \t light_chain

    Args:
        positive_file: Path of the file whose rows are labelled 1.0.
        negative_file: Path of the file whose rows are labelled 0.0.

    Returns:
        tuple: ``(seq_pairs, labels)`` where ``seq_pairs`` is a list of
        ``(antigen, heavy + light)`` tuples, positives first and then
        negatives, and ``labels`` is the parallel list of float labels.
    """
    positives = _read_pairs(positive_file)
    negatives = _read_pairs(negative_file)

    seq_pairs = positives + negatives
    labels = [1.0] * len(positives) + [0.0] * len(negatives)
    return seq_pairs, labels

In [None]:
# The 20 standard amino acids, in the order that fixes the feature layout.
AA = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
DP = list(product(AA, AA))
# All 400 ordered residue pairs as two-letter strings: "AA", "AC", ..., "YY".
DP_list = [first + second for first, second in DP]

AAindex_list = DP_list.copy()

# Position of every dipeptide in DP_list, so counting needs no list search.
_DP_INDEX = {dipeptide: position for position, dipeptide in enumerate(DP_list)}


def returnCKSAAPcode(query_seq, k):
    """Encode a sequence as k-spaced residue-pair frequencies.

    For each gap ``g`` in ``0..k`` the function counts, for every one of the
    400 dipeptides in ``DP_list``, how often its two residues occur ``g``
    positions apart in ``query_seq``, and divides by the number of residue
    pairs at that gap (``len(query_seq) - g - 1``).

    Args:
        query_seq: Sequence string to encode.
        k: Largest gap to consider; gaps ``0..k`` are all encoded.

    Returns:
        list[float]: Always exactly ``(k + 1) * len(DP_list)`` values, gap 0
        first, each gap laid out in ``DP_list`` order.

    Notes:
        * Pairs containing a character outside ``AA`` are not given a feature
          of their own (that used to lengthen the vector); they still count
          in the denominator.
        * A gap with no pairs at all (sequence too short) yields zeros instead
          of dividing by zero.
    """
    n_features = len(DP_list)
    code_final = []

    for gap in range(k + 1):
        n_pairs = len(query_seq) - gap - 1
        counts = [0] * n_features

        # range() of a non-positive number is empty, so short sequences
        # simply leave every count at zero.
        for start in range(n_pairs):
            position = _DP_INDEX.get(query_seq[start] + query_seq[start + gap + 1])
            if position is not None:
                counts[position] += 1

        if n_pairs > 0:
            code_final.extend(count / n_pairs for count in counts)
        else:
            code_final.extend([0.0] * n_features)

    return code_final

def get_cksaap_length(sample_seq, k):
    """Return how many features ``returnCKSAAPcode`` produces for ``sample_seq`` at gap limit ``k``."""
    return len(returnCKSAAPcode(sample_seq, k))

In [None]:
class SequenceDataset(Dataset):
    """Dataset of sequence pairs, CKSAAP-encoded on access.

    Args:
        seq_pairs: Indexable collection of ``(seq1, seq2)`` pairs.
        labels: Label for each pair, index-aligned with ``seq_pairs``.
        k: Gap limit passed to ``returnCKSAAPcode`` for both sequences.
    """

    def __init__(self, seq_pairs, labels, k):
        self.seq_pairs, self.labels, self.k = seq_pairs, labels, k

    def __len__(self):
        """Number of sequence pairs."""
        return len(self.seq_pairs)

    def __getitem__(self, idx):
        """Return ``(encoded_seq1, encoded_seq2, label)`` tensors for one pair."""
        first, second = self.seq_pairs[idx]
        # When the first entry is a key of ANTIGEN_SEQUENCES it is replaced by
        # the stored sequence; otherwise it is encoded as given.
        first = ANTIGEN_SEQUENCES.get(first, first)

        encoded = [
            torch.tensor(returnCKSAAPcode(seq, self.k), dtype=torch.float32)
            for seq in (first, second)
        ]
        # The label is stored as a class index (dtype long), the target form
        # taken by the nn.CrossEntropyLoss used for training in this notebook.
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return encoded[0], encoded[1], target

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class SiameseNetwork(nn.Module):
    """Two-input CNN classifier with a shared convolutional branch.

    Both inputs go through the same ``cnn`` stack. The two feature vectors,
    their absolute difference and their element-wise product are concatenated
    and passed to a small fully connected head that outputs two logits.

    Args:
        input_shape: ``(channels, height, width)`` of one input sample.
    """

    def __init__(self, input_shape):
        super().__init__()
        c, h, w = input_shape
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 10, 3, 1), nn.BatchNorm2d(10), nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Conv2d(10, 20, 3, 1), nn.BatchNorm2d(20), nn.LeakyReLU(),
            Flatten()
        )
        # Measure the flattened feature size with a dummy input. The probe
        # runs in eval mode so that BatchNorm does not fold the all-zero dummy
        # batch into its running statistics and Dropout draws nothing from the
        # RNG; training mode (the module default) is restored afterwards.
        self.cnn.eval()
        with torch.no_grad():
            flat = self.cnn(torch.zeros(1, c, h, w)).shape[1]
        self.cnn.train()

        # The head sees four feature vectors side by side: o1, o2, |o1-o2|, o1*o2.
        self.fc = nn.Sequential(
            nn.Linear(flat * 4, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )

    def forward_once(self, x):
        """Run one input through the shared convolutional branch."""
        return self.cnn(x)

    def forward(self, x1, x2):
        """Return the two-class logits for a batch of input pairs."""
        o1 = self.forward_once(x1)
        o2 = self.forward_once(x2)
        diff = torch.abs(o1 - o2)
        prod = o1 * o2
        return self.fc(torch.cat([o1, o2, diff, prod], dim=1))

In [None]:
# Show where the notebook is running and what sits in that directory.
print("CWD:", os.getcwd())

print("Contents:", os.listdir())

# Relative location of the dataset folder, checked against the working directory.
sub = "AbAgIntPre/CoV-AbDab"
if not os.path.isdir(sub):
    print(f"\nNo folder named {sub} here.")
else:
    print(f"\nContents of {sub}:", os.listdir(sub))


In [None]:
def get_collate_fn(input_shape):
    """Build a DataLoader ``collate_fn`` for ``(seq1, seq2, label)`` samples.

    The returned function turns a list of samples into one batch:

    1. The list of sample tuples is unpacked into one group per field.
    2. Every sequence tensor is viewed as ``input_shape``, the
       (channels, height, width) layout expected by the CNN.
    3. ``torch.stack`` joins the equally shaped tensors along a new leading
       dimension, giving tensors of shape ``(batch_size, ...)``.

    Args:
        input_shape: Target shape for each individual sequence tensor.

    Returns:
        callable: ``collate_fn(batch) -> (seq1_batch, seq2_batch, labels)``.
    """
    def collate_fn(batch):
        firsts, seconds, targets = zip(*batch)

        def _stack_reshaped(vectors):
            # Reshape each sample, then add the batch dimension.
            return torch.stack([vec.view(input_shape) for vec in vectors])

        return _stack_reshaped(firsts), _stack_reshaped(seconds), torch.stack(targets)

    return collate_fn

In [None]:
# NOTE(review): Flatten and SiameseNetwork are also defined in an earlier cell
# of this notebook; when cells run top to bottom, the definitions below are the
# ones in effect. Keeping a single copy would avoid the two drifting apart.
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class SiameseNetwork(nn.Module):
    """Two-input CNN classifier with a shared convolutional branch.

    Both inputs go through the same ``cnn`` stack. The two feature vectors,
    their absolute difference and their element-wise product are concatenated
    and passed to a small fully connected head that outputs two logits.

    Args:
        input_shape: ``(channels, height, width)`` of one input sample.
    """

    def __init__(self, input_shape):
        super().__init__()
        c, h, w = input_shape
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 10, 3, 1), nn.BatchNorm2d(10), nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Conv2d(10, 20, 3, 1), nn.BatchNorm2d(20), nn.LeakyReLU(),
            Flatten()
        )
        # Measure the flattened feature size with a dummy input. The probe
        # runs in eval mode so that BatchNorm does not fold the all-zero dummy
        # batch into its running statistics and Dropout draws nothing from the
        # RNG; training mode (the module default) is restored afterwards.
        self.cnn.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            flat = self.cnn(dummy).shape[1]
        self.cnn.train()

        # The head sees four feature vectors side by side: o1, o2, |o1-o2|, o1*o2.
        self.fc = nn.Sequential(
            nn.Linear(flat * 4, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )

    def forward_once(self, x):
        """Run one input through the shared convolutional branch."""
        return self.cnn(x)

    def forward(self, x1, x2):
        """Return the two-class logits for a batch of input pairs."""
        o1 = self.forward_once(x1)
        o2 = self.forward_once(x2)
        diff = torch.abs(o1 - o2)
        prod = o1 * o2
        return self.fc(torch.cat([o1, o2, diff, prod], dim=1))


def plot_roc_curves(roc_data, title=None):
    """Draw one ROC curve per fold on a shared figure and show it.

    Args:
        roc_data: Iterable of ``(fpr, tpr, fold_auc)`` triples, one per fold;
            folds are numbered from 1 in the legend.
        title: Optional figure title; omitted when falsy.
    """
    _, ax = plt.subplots()
    for i, (fpr, tpr, fold_auc) in enumerate(roc_data, 1):
        ax.plot(fpr, tpr, label=f"Fold {i} (AUC={fold_auc:.3f})")
    # Dashed diagonal for reference.
    ax.plot([0, 1], [0, 1], "k--", linewidth=1)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    if title:
        ax.set_title(title)
    ax.legend()
    plt.show()

def plot_pr_curves(pr_data, title=None):
    """Draw one precision-recall curve per fold on a shared figure and show it.

    Args:
        pr_data: Iterable of ``(prec, rec, pr_auc)`` triples, one per fold;
            recall goes on the x axis and precision on the y axis.
        title: Optional figure title; omitted when falsy.
    """
    _, ax = plt.subplots()
    for i, (prec, rec, pr_auc) in enumerate(pr_data, 1):
        ax.plot(rec, prec, label=f"Fold {i} (AUC={pr_auc:.3f})")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    if title:
        ax.set_title(title)
    ax.legend()
    plt.show()


def _collect_val_predictions(model, loader, device):
    """Put ``model`` in eval mode and gather labels and class-1 probabilities over ``loader``.

    Returns:
        tuple: ``(all_y, all_p)`` as plain Python lists, in loader order.
    """
    model.eval()
    all_y, all_p = [], []
    with torch.no_grad():
        for x1, x2, y in loader:
            x1, x2 = x1.to(device), x2.to(device)
            logits = model(x1, x2)
            probs = torch.softmax(logits, dim=1)[:, 1]
            all_p.extend(probs.cpu().tolist())
            all_y.extend(y.tolist())
    return all_y, all_p


def cross_validate_model(seq_pairs, labels, k=4, folds=5,
                         epochs=50, bs=32, lr=1e-3, patience=5):
    """Stratified k-fold cross-validation of ``SiameseNetwork``.

    For every fold a fresh model is trained with Adam and cross-entropy loss.
    After each epoch the validation ROC-AUC is computed and used to
    (a) step a ``ReduceLROnPlateau`` scheduler, (b) save the best weights so
    far to ``best_fold{fold}.pth`` in the working directory, and (c) stop
    early after ``patience`` epochs without improvement. The best checkpoint
    is then reloaded to produce the fold's ROC and precision-recall curves.

    Args:
        seq_pairs: Sequence pairs, as accepted by ``SequenceDataset``.
        labels: One label per pair; also used to stratify the folds.
        k: CKSAAP gap limit; the model input shape is ``(k + 1, 20, 20)``.
        folds: Number of stratified folds.
        epochs: Maximum number of training epochs per fold.
        bs: Batch size for both loaders.
        lr: Initial Adam learning rate.
        patience: Epochs without validation-AUC improvement before stopping.

    Returns:
        tuple: ``(best_aucs, roc_data, pr_data)`` with one entry per fold:
        the best validation AUC, ``(fpr, tpr, best_auc)`` and
        ``(precision, recall, pr_auc)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_shape = (k + 1, 20, 20)
    skf = StratifiedKFold(n_splits=folds, shuffle=True, random_state=42)
    # The collate function depends only on input_shape, so build it once.
    collate = get_collate_fn(input_shape)

    best_aucs = []
    roc_data = []
    pr_data = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(seq_pairs, labels), 1):
        print(f"\n=== Fold {fold}/{folds} ===")
        train_ds = SequenceDataset([seq_pairs[i] for i in train_idx],
                                   [labels[i] for i in train_idx], k)
        val_ds   = SequenceDataset([seq_pairs[i] for i in val_idx],
                                   [labels[i] for i in val_idx], k)
        tr_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=bs, shuffle=False, collate_fn=collate)

        model = SiameseNetwork(input_shape).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                         factor=0.5, patience=2, min_lr=1e-6)

        checkpoint_path = f"best_fold{fold}.pth"
        # Start below any possible AUC so the first epoch always writes a
        # checkpoint. Starting at 0.0 meant an AUC of exactly 0.0 saved
        # nothing, and the reload below could then pick up a stale file left
        # by an earlier run.
        best_auc = float("-inf")
        wait = 0

        # Training loop
        for epoch in range(1, epochs + 1):
            model.train()
            for x1, x2, y in tr_loader:
                x1, x2, y = x1.to(device), x2.to(device), y.to(device)
                optimizer.zero_grad()
                logits = model(x1, x2)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()

            # Validation
            all_y, all_p = _collect_val_predictions(model, val_loader, device)

            fold_auc = roc_auc_score(all_y, all_p)
            scheduler.step(fold_auc)
            print(f"Epoch {epoch}: Val AUC = {fold_auc:.4f}  LR = {optimizer.param_groups[0]['lr']:.2e}")

            # Early stopping logic
            if fold_auc > best_auc:
                best_auc = fold_auc
                wait = 0
                torch.save(model.state_dict(), checkpoint_path)
            else:
                wait += 1
                if wait >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

        # Score the fold with its best checkpoint, not the last-epoch weights.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        all_y, all_p = _collect_val_predictions(model, val_loader, device)

        best_aucs.append(best_auc)
        fpr, tpr, _ = roc_curve(all_y, all_p)
        roc_data.append((fpr, tpr, best_auc))
        prec, rec, _ = precision_recall_curve(all_y, all_p)
        pr_auc = auc(rec, prec)
        pr_data.append((prec, rec, pr_auc))

    mean_auc = np.mean(best_aucs)
    std_auc = np.std(best_aucs)
    print(f"\n>>> Mean CV AUC = {mean_auc:.3f} ± {std_auc:.3f}")

    return best_aucs, roc_data, pr_data

In [None]:
def plot_confusion(cm, title, out_path, classes=['Neg','Pos']):
    """Render a confusion matrix as an annotated heat map and save it.

    Args:
        cm: 2-D array of counts; ``cm[i, j]`` is drawn at row ``i``, column
            ``j`` (rows labelled "True", columns labelled "Predicted").
        title: Prefix for the figure title; " Confusion" is appended.
        out_path: Destination of the saved image (200 dpi).
        classes: Tick labels, one per class. The default list is only read,
            never modified.
    """
    plt.figure(figsize=(4,4))
    plt.imshow(cm, interpolation='nearest', cmap='Blues')
    plt.title(f"{title} Confusion")
    plt.colorbar()

    # Cells darker than half the maximum get white text so it stays readable.
    threshold = cm.max()/2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, cm[i,j], ha='center', va='center',
                     color='white' if cm[i,j] > threshold else 'black')

    # One tick per class label instead of a fixed [0, 1], so the function
    # also works for more than two classes.
    ticks = list(range(len(classes)))
    plt.xticks(ticks, classes)
    plt.yticks(ticks, classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()

def plot_roc(fpr, tpr, roc_auc, title, out_path):
    """Plot a single ROC curve with a dashed diagonal and save it to ``out_path``.

    Args:
        fpr: False-positive rates (x values).
        tpr: True-positive rates (y values).
        roc_auc: Area under the curve, shown in the legend.
        title: Prefix for the axes title; " ROC" is appended.
        out_path: Destination of the saved image (200 dpi).
    """
    fig, ax = plt.subplots(figsize=(4,4))

    ax.plot(fpr, tpr, label=f"AUC={roc_auc:.3f}", linewidth=2)
    ax.plot([0,1],[0,1],'--', color='orange', linewidth=2)

    ax.set(xlabel="FPR", ylabel="TPR", title=f"{title} ROC")
    ax.legend(loc='upper left', frameon=True, fontsize='small')
    ax.set_aspect('equal', adjustable='box')

    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)

# Folder for the saved figures, created in the working directory when missing.
OUTPUT_DIR = Path.cwd() / "final_model_outputs"
OUTPUT_DIR.mkdir(exist_ok=True)

# Settings for the final train/test run.
k = 4
input_shape = (k+1, 20, 20)
bs = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_epochs = 30
lr = 1e-4

# The four split files all live in the same dataset folder.
train_pos_file, train_neg_file, test_pos_file, test_neg_file = (
    Path.cwd() / "AbAgIntPre" / "CoV-AbDab" / split_name
    for split_name in ("train_pos.txt", "train_neg.txt", "test_pos.txt", "test_neg.txt")
)

seq_pairs_train, train_labels = load_data(train_pos_file, train_neg_file)
seq_pairs_test,  test_labels  = load_data(test_pos_file,  test_neg_file)

train_ds = SequenceDataset(seq_pairs_train, train_labels, k)
test_ds  = SequenceDataset(seq_pairs_test,  test_labels,  k)

collate = get_collate_fn(input_shape)

# Training batches are shuffled; test batches keep file order.
train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate)
test_loader = DataLoader(test_ds, batch_size=bs, shuffle=False, collate_fn=collate)


# Fresh model, loss and optimiser for the final fit on the full training set.
model = SiameseNetwork(input_shape).to(device)
criterion, optimizer = nn.CrossEntropyLoss(), optim.Adam(model.parameters(), lr=lr)

print(f"Training on {len(train_ds)} examples for {n_epochs} epochs…")
for epoch in range(1, n_epochs+1):
    t0, running_loss = time.time(), 0.0
    model.train()

    for x1, x2, y in train_loader:
        x1, x2, y = (tensor.to(device) for tensor in (x1, x2, y))

        optimizer.zero_grad()
        logits = model(x1, x2)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so the epoch figure is a
        # per-example mean even when the last batch is smaller.
        running_loss += y.size(0) * loss.item()

    avg_loss = running_loss / len(train_ds)
    print(f"Epoch {epoch:02d}  loss={avg_loss:.4f}  time={(time.time()-t0):.1f}s")

# Score the held-out test set with the trained model.
model.eval()
all_labels, all_probs = [], []

with torch.no_grad():
    for x1, x2, y in test_loader:
        x1, x2 = x1.to(device), x2.to(device)
        logits = model(x1, x2)
        # Column 1 of the softmax output is the probability of class 1.
        probs = logits.softmax(dim=1)[:, 1].cpu().numpy()
        all_probs.extend(probs)
        all_labels.extend(y.numpy())

# Probabilities of at least 0.5 are predicted as class 1.
all_preds = (np.asarray(all_probs) >= 0.5).astype(int)

test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
test_roc_auc  = roc_auc_score(all_labels, all_probs)
print(f"\nTest balanced acc: {test_bal_acc:.3f}   ROC-AUC: {test_roc_auc:.3f}")
print(classification_report(all_labels, all_preds, digits=3))

cm = confusion_matrix(all_labels, all_preds)
# plot_confusion and plot_roc append " Confusion" / " ROC" to the title
# themselves, so only the model name is passed (the previous title produced
# "... ROC Confusion" and "... ROC ROC").
plot_confusion(cm, title="Siamese AbAgIntPre", out_path=OUTPUT_DIR/"test_confusion.png")

fpr, tpr, _ = roc_curve(all_labels, all_probs)
plot_roc(fpr, tpr, test_roc_auc, title="Siamese AbAgIntPre", out_path=OUTPUT_DIR/"test_roc.png")

# Second rendering of the same confusion matrix, drawn with seaborn.
plt.figure(figsize=(4,4))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.title("Siamese Confusion")
plt.tight_layout()
plt.savefig(OUTPUT_DIR/"Siamese_confusion.png", dpi=200)
plt.close()

# Second rendering of the ROC curve; fpr/tpr from above are reused rather
# than recomputed.
plt.figure()
plt.plot(fpr, tpr, label=f'AUC={test_roc_auc:.3f}')
plt.plot([0,1],[0,1],'--')
plt.title('Siamese AbAgIntPre ROC')
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.legend()
plt.savefig(OUTPUT_DIR/"siamese_roc.png", dpi=200)
plt.close()

print(f"\nSaved test_confusion.png, test_roc.png, Siamese_confusion.png and siamese_roc.png in {OUTPUT_DIR}")