In [None]:
from Bio import SeqIO
from Bio.Align import PairwiseAligner
import csv
import random

# ---- File locations ----
# Query protein sequences (the file name suggests a sample of 1000 proteins;
# the original note mentioned 100 -- confirm which is correct).
QUERY_FASTA = "sampled_1000_proteins.fasta"
# Complete Swiss-Prot database, used as the pool of candidate partners.
FULL_SWISSPROT = "uniprot_sprot.fasta"
# Destination CSV for the negative (non-homologous) sequence pairs.
OUTPUT_CSV = "non_homologous_pairs.csv"

# ---- Sampling parameters (suitable for a paper's Methods section) ----
# Number of negative samples to generate for each query sequence.
MAX_NEG_PER_QUERY = 5
# A pair whose percent identity is below this value counts as non-homologous.
IDENTITY_THRESHOLD = 25.0
# Largest accepted relative length difference between the two sequences (30%).
LENGTH_TOLERANCE = 0.3
# Sequences shorter than this many residues are filtered out.
MIN_LENGTH = 50
# Upper bound on random draws per query (the author considered 300 draws
# sufficient to find 5 negatives).
MAX_TRIALS = 300

# Global (Needleman-Wunsch-style) pairwise aligner shared by every comparison.
# Keyword arguments are applied as attributes in the order written, so
# gap_score (-5 for every gap position) is set first and extend_gap_score (-1)
# then overrides only the extension penalty -- per Biopython's documented
# attribute semantics this yields an affine scheme: open -5, extend -1.
aligner = PairwiseAligner(
    mode="global",
    match_score=2,
    mismatch_score=-1,
    gap_score=-5,
    extend_gap_score=-1,
)

# Load the input data.
def _long_enough(record):
    """Return True when the record's sequence has at least MIN_LENGTH residues."""
    return len(record.seq) >= MIN_LENGTH


print("Loading query sequences...")
query_records = list(filter(_long_enough, SeqIO.parse(QUERY_FASTA, "fasta")))

print("Loading full Swiss-Prot database... (this may take a minute)")
full_records = list(filter(_long_enough, SeqIO.parse(FULL_SWISSPROT, "fasta")))
print(f"Loaded {len(full_records)} sequences from Swiss-Prot.")

# Identifiers of all retained query records.
query_ids = set(rec.id for rec in query_records)

def _percent_identity(row_a, row_b):
    """Return the percent identity over the gap-free columns of two aligned rows.

    ``row_a`` and ``row_b`` are the gapped rows of one pairwise alignment:
    equal-length strings in which ``-`` marks a gap. Columns that contain a
    gap are excluded from both the match count and the denominator. Returns 0
    when there is no gap-free column.
    """
    identical = sum(a == b and a != "-" for a, b in zip(row_a, row_b))
    # Assumes no column has a gap in both rows (true for pairwise alignments),
    # so subtracting both gap counts leaves the residue-vs-residue columns.
    aligned_len = len(row_a) - row_a.count("-") - row_b.count("-")
    return 100.0 * identical / aligned_len if aligned_len > 0 else 0


# Generate negative samples: for each query, draw random Swiss-Prot entries
# and keep those of similar length whose global-alignment identity is below
# IDENTITY_THRESHOLD. Each kept pair is written as (query, candidate, 0).
# NOTE: the random module is not seeded here, so runs are not reproducible.
with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["sentence1", "sentence2", "label"])

    for i, query_rec in enumerate(query_records, 1):
        q_seq = str(query_rec.seq)
        q_len = len(q_seq)
        q_id = query_rec.id

        print(f"[{i}/{len(query_records)}] Processing {q_id} (length={q_len})")

        found = 0
        trials = 0
        # Ids that must not be (re)examined for this query: the query itself
        # plus every candidate already drawn. Re-drawing a candidate would
        # either repeat a rejection or write a duplicate negative pair.
        seen_ids = {q_id}

        while found < MAX_NEG_PER_QUERY and trials < MAX_TRIALS:
            trials += 1
            cand_rec = random.choice(full_records)

            if cand_rec.id in seen_ids:
                continue
            seen_ids.add(cand_rec.id)

            c_seq = str(cand_rec.seq)
            c_len = len(c_seq)

            # Length-similarity filter: difference relative to the mean length.
            if abs(q_len - c_len) / ((q_len + c_len) / 2) > LENGTH_TOLERANCE:
                continue

            # Global alignment. Take the first optimal alignment through the
            # iterator: len() on the result can raise OverflowError when the
            # number of co-optimal alignments is astronomically large.
            alignment = next(iter(aligner.align(q_seq, c_seq)), None)
            if alignment is None:
                continue

            # Percent identity over gap-free columns only.
            identity = _percent_identity(str(alignment[0]), str(alignment[1]))

            if identity < IDENTITY_THRESHOLD:
                writer.writerow([q_seq, c_seq, 0])
                found += 1

        print(f"  Found {found} negative pairs (after {trials} trials)")

print(f"\nDone! Negative samples saved to {OUTPUT_CSV}")

Loading query sequences...
Loading full Swiss-Prot database... (this may take a minute)
