In [None]:
import csv
import os
import random
import subprocess
from collections import defaultdict
from pathlib import Path

from Bio import SeqIO
from sklearn.model_selection import train_test_split

In [9]:
def load_antigen_sequences(fasta_path):
    """Read a FASTA file and return a dict mapping record ID to sequence string.

    Args:
        fasta_path: Path of the FASTA file to open for reading.

    Returns:
        dict: ``record.id`` → ``str(record.seq)`` for every record yielded by
        ``SeqIO.parse``. If an ID occurs more than once, the last record wins.
    """
    with open(fasta_path, "r") as handle:
        return {
            entry.id: str(entry.seq)
            for entry in SeqIO.parse(handle, "fasta")
        }

# Locate the antigen FASTA in the current working directory.
fasta_path = Path.cwd() / "antigens.fasta"
# Module-level lookup table (record ID -> sequence string) built by
# load_antigen_sequences().
ANTIGEN_SEQUENCES = load_antigen_sequences(fasta_path)

# NOTE(review): this prints every full-length sequence into the cell output;
# consider showing only the IDs and lengths when sharing the notebook.
print(ANTIGEN_SEQUENCES)

{'SARS-CoV1': 'MFIFLLFLTLTSGSDLDRCTTFDDVQAPNYTQHTSSMRGVYYPDEIFRSDTLYLTQDLFLPFYSNVTGFHTINHTFGNPVIPFKDGIYFAATEKSNVVRGWVFGSTMNNKSQSVIIINNSTNVVIRACNFELCDNPFFAVSKPMGTQTHTMIFDNAFNCTFEYISDAFSLDVSEKSGNFKHLREFVFKNKDGFLYVYKGYQPIDVVRDLPSGFNTLKPIFKLPLGINITNFRAILTAFSPAQDIWGTSAAAYFVGYLKPTTFMLKYDENGTITDAVDCSQNPLAELKCSVKSFEIDKGIYQTSNFRVVPSGDVVRFPNITNLCPFGEVFNATKFPSVYAWERKKISNCVADYSVLYNSTFFSTFKCYGVSATKLNDLCFSNVYADSFVVKGDDVRQIAPGQTGVIADYNYKLPDDFMGCVLAWNTRNIDATSTGNYNYKYRYLRHGKLRPFERDISNVPFSPDGKPCTPPALNCYWPLNDYGFYTTTGIGYQPYRVVVLSFELLNAPATVCGPKLSTDLIKNQCVNFNFNGLTGTGVLTPSSKRFQPFQQFGRDVSDFTDSVRDPKTSEILDISPCSFGGVSVITPGTNASSEVAVLYQDVNCTDVSTAIHADQLTPAWRIYSTGNNVFQTQAGCLIGAEHVDTSYECDIPIGAGICASYHTVSLLRSTSQKSIVAYTMSLGADSSIAYSNNTIAIPTNFSISITTEVMPVSMAKTSVDCNMYICGDSTECANLLLQYGSFCTQLNRALSGIAAEQDRNTREVFAQVKQMYKTPTLKYFGGFNFSQILPDPLKPTKRSFIEDLLFNKVTLADAGFMKQYGECLGDINARDLICAQKFNGLTVLPPLLTDDMIAAYTAALVSGTATAGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKQIANQFNKAISQIQESLTTTSTALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQS

In [7]:
import os
import subprocess
import csv
import random
from collections import defaultdict
from Bio import SeqIO
from sklearn.model_selection import train_test_split

# =======================
# USER-ADJUSTABLE PARAMETERS
# =======================
# Everything a user is expected to tune lives in this block:
#   WORKDIR          - base directory for all inputs and outputs (current dir by default)
#   IDENTITY_CUTOFF  - CD-HIT identity threshold (only relevant when USE_CDHIT is True)
#   CDHIT_THREADS    - number of threads handed to CD-HIT
#   TEST_SIZE        - fraction of the balanced data held out as the test set
#   filter_substr    - keep only antigen IDs containing this substring (None keeps all)
#   USE_CDHIT        - toggle for the optional clustering step
WORKDIR = Path.cwd()

# All dataset files live in one sub-directory of WORKDIR.
_DATASET_DIR = WORKDIR.joinpath("AbAgIntPre", "CoV-AbDab")

# Inputs
ANTIGEN_FASTA = WORKDIR / "antigens.fasta"
POS_TSV = _DATASET_DIR / "positive dataset.txt"
NEG_TSV = _DATASET_DIR / "negative dataset.txt"

# CD-HIT intermediates (the "nr98" file name is fixed; it does not track IDENTITY_CUTOFF)
PAIR_FASTA = _DATASET_DIR / "pairs.fasta"
CDHIT_OUT = _DATASET_DIR / "pairs_nr98.fasta"

# Balanced dataset outputs
POS_OUT = _DATASET_DIR / "positive_balanced.txt"
NEG_OUT = _DATASET_DIR / "negative_balanced.txt"

# Train/test split outputs
TRAIN_POS = _DATASET_DIR / "train_pos.txt"
TRAIN_NEG = _DATASET_DIR / "train_neg.txt"
TEST_POS = _DATASET_DIR / "test_pos.txt"
TEST_NEG = _DATASET_DIR / "test_neg.txt"

IDENTITY_CUTOFF = 0.7
CDHIT_THREADS = 4
TEST_SIZE = 0.10

# Substring that an antigen ID must contain to be kept; None disables the filter.
filter_substr = 'SARS-CoV2'

# ──────────────────────────────────────────────────────────────────────────────
# Optional CD-HIT step
# ──────────────────────────────────────────────────────────────────────────────
# False skips clustering entirely and continues with the raw pairs.
USE_CDHIT = False

# ──────────────────────────────────────────────────────────────────────────────
# Utility: convert Windows path to WSL mount path
# ──────────────────────────────────────────────────────────────────────────────
def win_to_wsl(win_path):
    """Translate a Windows path into its WSL ``/mnt/<drive>/...`` equivalent.

    Windows path rules (``ntpath``) are applied explicitly, so a path such as
    ``C:\\Users\\me\\x.fa`` is converted correctly even when this code runs on
    a non-Windows interpreter. A path without a drive letter on a non-Windows
    host is already a native POSIX path and is returned as an absolute path
    instead of being rewritten to the invalid ``/mnt//...``.

    Args:
        win_path: ``str`` or ``os.PathLike`` path. Relative paths are resolved
            against the current working directory.

    Returns:
        str: ``/mnt/<lowercase drive letter>/<rest with forward slashes>`` for
        Windows paths, or the absolute POSIX path when no drive letter applies.
    """
    import ntpath  # local import: Windows path semantics regardless of host OS

    raw = os.fspath(win_path)
    has_drive = bool(ntpath.splitdrive(raw)[0])
    if os.name != "nt" and not has_drive:
        # Native POSIX path on a POSIX host: there is nothing to translate.
        return os.path.abspath(raw)

    # On Windows ntpath *is* os.path, so this matches the previous behaviour.
    abspath = ntpath.abspath(raw)
    drive, rest = ntpath.splitdrive(abspath)
    letter = drive.rstrip(":").lower()
    rest = rest.replace("\\", "/").lstrip("/")
    return f"/mnt/{letter}/{rest}"

# ──────────────────────────────────────────────────────────────────────────────
# Load antigen sequences from FASTA into a dict: ID → sequence
# ──────────────────────────────────────────────────────────────────────────────
# Module-level lookup table: FASTA record ID → sequence as a plain string.
ANTIGEN_SEQS = dict(
    (entry.id, str(entry.seq))
    for entry in SeqIO.parse(ANTIGEN_FASTA, "fasta")
)

# ──────────────────────────────────────────────────────────────────────────────
# Read original triplets TSV, look up antigen sequence by ID
# ──────────────────────────────────────────────────────────────────────────────
def load_triplets(pos_file, neg_file, filter_substr=None, antigen_seqs=None,
                  min_antigen_len=50):
    """Load (antigen, heavy, light) rows from the positive and negative TSVs.

    Each row must have at least three tab-separated fields: antigen ID, heavy
    chain and light chain. Surrounding whitespace is stripped from all three.
    Rows are silently skipped when they have fewer than three fields, when the
    antigen ID does not contain ``filter_substr``, when the antigen ID has no
    entry in the lookup table, or when the antigen sequence is shorter than
    ``min_antigen_len``.

    Args:
        pos_file: Path of the TSV with interacting pairs (labelled 1).
        neg_file: Path of the TSV with non-interacting pairs (labelled 0).
        filter_substr: If truthy, keep only antigen IDs containing it.
        antigen_seqs: Mapping of antigen ID → sequence. Defaults to the
            module-level ``ANTIGEN_SEQS`` table.
        min_antigen_len: Minimum antigen sequence length to keep (default 50).

    Returns:
        list[tuple]: ``(ag_id, ag_seq, heavy, light, label)`` tuples, all
        positives first, then all negatives, each in file order.
    """
    if antigen_seqs is None:
        antigen_seqs = ANTIGEN_SEQS

    trips = []
    for fn, label in ((pos_file, 1), (neg_file, 0)):
        # newline="" lets the csv module do its own line-ending handling,
        # as its documentation requires for file objects.
        with open(fn, newline="") as f:
            for row in csv.reader(f, delimiter="\t"):
                if len(row) < 3:
                    continue
                ag_id, heavy, light = (field.strip() for field in row[:3])
                if filter_substr and filter_substr not in ag_id:
                    continue
                ag_seq = antigen_seqs.get(ag_id)
                if ag_seq is None or len(ag_seq) < min_antigen_len:
                    continue
                trips.append((ag_id, ag_seq, heavy, light, label))
    return trips

# ──────────────────────────────────────────────────────────────────────────────
# Write concatenated antigen+heavy+light FASTA for CD-HIT
# ──────────────────────────────────────────────────────────────────────────────
def write_pair_fasta(trips, fasta_path):
    """Write one FASTA record per triplet for clustering.

    Record ``i`` (zero-based position in ``trips``) gets the header
    ``>Pair{i}`` and a sequence line made of antigen + heavy + light
    concatenated in that order.

    Args:
        trips: Iterable of ``(ag_id, ag_seq, heavy, light, label)`` tuples.
        fasta_path: Destination file; overwritten if it exists.
    """
    with open(fasta_path, "w") as handle:
        handle.writelines(
            f">Pair{i}\n{seq}\n"
            for i, seq in enumerate(
                antigen + heavy_chain + light_chain
                for _name, antigen, heavy_chain, light_chain, _label in trips
            )
        )

# ──────────────────────────────────────────────────────────────────────────────
# Run CD-HIT on concatenated sequences
# ──────────────────────────────────────────────────────────────────────────────
def run_cdhit(in_fa, out_fa, identity, threads):
    """Run CD-HIT on ``in_fa`` and write representatives to ``out_fa``.

    On Windows (``os.name == "nt"``) the command is run through ``wsl`` and
    both paths are translated to ``/mnt/<drive>/...`` form with
    ``win_to_wsl``. On every other platform ``cd-hit`` is invoked directly
    with absolute native paths; rewriting those to ``/mnt/...`` would point
    CD-HIT at files that do not exist.

    Args:
        in_fa: Input FASTA path.
        out_fa: Output FASTA path passed to ``-o``.
        identity: Sequence identity threshold passed to ``-c``, formatted to
            two decimals.
        threads: Thread count passed to ``-T``.

    Raises:
        subprocess.CalledProcessError: If CD-HIT exits with a non-zero status.
    """
    # Word size (-n) has to shrink as the identity threshold drops. Thresholds
    # of 0.6 and above keep the values this pipeline has always used; lower
    # thresholds follow the CD-HIT user guide table (0.5-0.6 → 3, 0.4-0.5 → 2).
    if identity >= 0.90:
        word_size = 5
    elif identity >= 0.60:
        word_size = 4
    elif identity >= 0.50:
        word_size = 3
    else:
        word_size = 2

    via_wsl = os.name == "nt"
    if via_wsl:
        in_arg, out_arg = win_to_wsl(in_fa), win_to_wsl(out_fa)
    else:
        in_arg = os.path.abspath(os.fspath(in_fa))
        out_arg = os.path.abspath(os.fspath(out_fa))

    cmd = [
        "cd-hit",
        "-i", in_arg,
        "-o", out_arg,
        "-c", f"{identity:.2f}",
        "-n", str(word_size),
        "-M", "0",
        "-T", str(threads),
    ]
    if via_wsl:
        cmd.insert(0, "wsl")
    print("Running CD-HIT:", " ".join(cmd))
    subprocess.run(cmd, check=True)

# ──────────────────────────────────────────────────────────────────────────────
# Parse the .clstr file to map PairID → clusterID
# ──────────────────────────────────────────────────────────────────────────────
def parse_pair_clusters(clstr_path):
    """Parse a ``.clstr`` file into a mapping of pair ID → cluster number.

    A line starting with ``>Cluster`` sets the current cluster to the integer
    in its second whitespace-separated field. Any other line containing ``>``
    is a member line: the text between the first ``>`` and the following
    ``...`` is taken as the pair ID and assigned to the current cluster.

    Args:
        clstr_path: Path of the cluster file to read.

    Returns:
        dict: pair ID → cluster number (``None`` for member lines that appear
        before any cluster header).
    """
    membership = {}
    active_cluster = None
    with open(clstr_path) as handle:
        for raw in handle:
            if raw.startswith(">Cluster"):
                active_cluster = int(raw.split()[1])
                continue
            pieces = raw.split(">", 2)
            if len(pieces) < 2:
                # No '>' on this line, so it names no member.
                continue
            member_id = pieces[1].partition("...")[0]
            membership[member_id] = active_cluster
    return membership

# ──────────────────────────────────────────────────────────────────────────────
# Balance so each CD-HIT cluster contributes equally
# ──────────────────────────────────────────────────────────────────────────────
def balance_by_cluster(trips, cluster_of, seed=None):
    """Down-sample so every cluster contributes the same number of triplets.

    The triplet at position ``idx`` is looked up as ``"Pair{idx}"`` in
    ``cluster_of``. Triplets whose ID is missing from ``cluster_of`` are
    grouped together under the key ``None``. Every group is randomly
    down-sampled to the size of the smallest group and the result is shuffled.

    Args:
        trips: Sequence of triplet tuples, in the order used to name the pairs.
        cluster_of: Mapping of pair ID (``"Pair<idx>"``) → cluster ID.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the result is reproducible and the global random state is
            untouched; when ``None`` the module-level generator is used.

    Returns:
        list: The balanced, shuffled triplets; ``[]`` when ``trips`` is empty.
    """
    if not trips:
        # min() over zero groups would raise ValueError.
        return []

    rng = random if seed is None else random.Random(seed)

    by_cluster = defaultdict(list)
    for idx, trip in enumerate(trips):
        by_cluster[cluster_of.get(f"Pair{idx}")].append(trip)

    target = min(len(group) for group in by_cluster.values())
    balanced = []
    for group in by_cluster.values():
        balanced.extend(rng.sample(group, target))
    rng.shuffle(balanced)
    return balanced

# ──────────────────────────────────────────────────────────────────────────────
# Write out positive vs negative TSV (using antigen IDs)
# ──────────────────────────────────────────────────────────────────────────────
def write_output(trips, pos_path, neg_path):
    """Split triplets by label into a positive and a negative TSV file.

    Each triplet becomes one ``antigen_id<TAB>heavy<TAB>light`` line. Triplets
    with ``label == 1`` go to ``pos_path``; all others go to ``neg_path``.
    Both files are always (re)created, even when nothing is written to them.

    Args:
        trips: Iterable of ``(ag_id, ag_seq, heavy, light, label)`` tuples.
        pos_path: Destination file for label-1 rows.
        neg_path: Destination file for every other row.
    """
    with open(pos_path, "w") as pos_handle, open(neg_path, "w") as neg_handle:
        for record in trips:
            ag_id, _unused_seq, heavy, light, label = record
            if label == 1:
                sink = pos_handle
            else:
                sink = neg_handle
            sink.write(f"{ag_id}\t{heavy}\t{light}\n")

# ──────────────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    os.chdir(WORKDIR)

    # Seed the global generator once: the cluster and class balancing below use
    # random.sample / random.shuffle, so without a seed the data reaching
    # train_test_split would differ on every run despite its fixed random_state.
    RANDOM_SEED = 42
    random.seed(RANDOM_SEED)

    all_trips = load_triplets(POS_TSV, NEG_TSV, filter_substr)
    print(f"Loaded {len(all_trips)} total pairs")

    if USE_CDHIT:
        write_pair_fasta(all_trips, PAIR_FASTA)
        run_cdhit(PAIR_FASTA, CDHIT_OUT, IDENTITY_CUTOFF, CDHIT_THREADS)
        clusters = parse_pair_clusters(f"{CDHIT_OUT}.clstr")
        balanced_clusters = balance_by_cluster(all_trips, clusters)
        print(f"Cluster-balanced down to {len(balanced_clusters)} pairs")
        # POS_OUT / NEG_OUT are written once, after class balancing below;
        # writing them here as well would just be overwritten.
    else:
        balanced_clusters = all_trips
        print(f"Skipping CD-HIT. Proceeding with {len(balanced_clusters)} raw pairs.")

    # Class-level 1:1 pos/neg balancing: down-sample the larger class.
    positives = [t for t in balanced_clusters if t[4] == 1]
    negatives = [t for t in balanced_clusters if t[4] == 0]
    cls_target = min(len(positives), len(negatives))
    if cls_target == 0:
        # train_test_split would fail further down with a less helpful message.
        raise ValueError(
            "No positive or no negative pairs left after loading/filtering "
            f"({len(positives)} pos, {len(negatives)} neg); "
            "check filter_substr and the input files."
        )
    positives_ds = random.sample(positives, cls_target)
    negatives_ds = random.sample(negatives, cls_target)
    balanced_class = positives_ds + negatives_ds
    random.shuffle(balanced_class)
    print(f"Class-balanced to {len(balanced_class)} pairs ({cls_target} pos + {cls_target} neg)")
    write_output(balanced_class, POS_OUT, NEG_OUT)

    # Stratified train/test split: TEST_SIZE is the held-out fraction and the
    # label ratio is preserved in both parts.
    labels = [t[4] for t in balanced_class]
    train, test = train_test_split(
        balanced_class,
        test_size=TEST_SIZE,
        stratify=labels,
        random_state=RANDOM_SEED
    )
    print(f"Final splits → train: {len(train)}, test: {len(test)}")

    # Write train/test pos & neg
    write_output(train, TRAIN_POS, TRAIN_NEG)
    write_output(test,  TEST_POS,  TEST_NEG)
    print("Wrote train/test files:")
    print(f"  {TRAIN_POS}, {TRAIN_NEG}")
    print(f"  {TEST_POS},  {TEST_NEG}")

C:\Users\franc\Desktop\MastersYear3-Thesis\PredictiveModelAntibodyAntigen
Loaded 8058 total pairs
Running CD-HIT: wsl cd-hit -i /mnt/c/Users/franc/Desktop/MastersYear3-Thesis/PredictiveModelAntibodyAntigen/AbAgIntPre/CoV-AbDab/pairs.fasta -o /mnt/c/Users/franc/Desktop/MastersYear3-Thesis/PredictiveModelAntibodyAntigen/AbAgIntPre/CoV-AbDab/pairs_nr98.fasta -c 0.70 -n 4 -M 0 -T 4
Cluster-balanced down to 8058 pairs
Class-balanced to 1428 pairs (714 pos + 714 neg)
Final splits → train: 1285, test: 143
Wrote train/test files:
  C:\Users\franc\Desktop\MastersYear3-Thesis\PredictiveModelAntibodyAntigen\AbAgIntPre\CoV-AbDab\train_pos.txt, C:\Users\franc\Desktop\MastersYear3-Thesis\PredictiveModelAntibodyAntigen\AbAgIntPre\CoV-AbDab\train_neg.txt
  C:\Users\franc\Desktop\MastersYear3-Thesis\PredictiveModelAntibodyAntigen\AbAgIntPre\CoV-AbDab\test_pos.txt,  C:\Users\franc\Desktop\MastersYear3-Thesis\PredictiveModelAntibodyAntigen\AbAgIntPre\CoV-AbDab\test_neg.txt
