In [1]:
# import os

# # 设置环境变量
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# # 打印环境变量以确认设置成功
# print(os.environ.get('HF_ENDPOINT'))

import subprocess
import os

# Best-effort proxy setup: source /etc/network_turbo in a bash subshell, dump
# the proxy-related variables it exports, and copy them into this process.
# If the script is missing or exports nothing, stdout is empty and the loop
# below is a no-op.
proxy_probe = subprocess.run(
    'bash -c "source /etc/network_turbo && env | grep proxy"',
    shell=True,
    capture_output=True,
    text=True,
)
for env_line in proxy_probe.stdout.splitlines():
    # Split on the first '=' only, so values that contain '=' stay intact.
    env_name, separator, env_value = env_line.partition('=')
    if separator:
        os.environ[env_name] = env_value

In [2]:
import json
import os
import sys

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, set_seed

In [3]:
# Run configuration. `seed` is reused later for the train/test split and in
# the confusion-matrix file name.
seed = 42
lang = "en"

# Seed the random number generators before the model is instantiated.
set_seed(seed)

# Accumulator for everything reported at the end of the notebook.
result = {
    "seed": seed,
    "type": "no_finetune_baseline",
}

In [None]:
# Initialise the tokenizer and the model from the same checkpoint.
model_checkpoint = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Load the model (pretrained weights + randomly initialised classification head).
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=2,
)
# Keep the model config's padding id in line with the tokenizer setting above.
model.config.pad_token_id = model.config.eos_token_id

# Prefer the GPU when one is available, then switch to evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

In [None]:
# 定义两个专用的分词函数
def tokenize_short_function(example):
    """
    Tokenize sentence pairs for the short subset.

    Args:
        example: Mapping with "sentence1" and "sentence2" entries (a single
            example or a batch of them).

    Returns:
        The tokenizer output, padded and truncated to exactly 256 tokens.
    """
    first_sequences = example["sentence1"]
    second_sequences = example["sentence2"]
    # Per the original author's note, 256 tokens is long enough that nothing
    # in the short subset is actually truncated.
    encoded = tokenizer(
        first_sequences,
        second_sequences,
        max_length=256,
        padding="max_length",
        truncation=True,
    )
    return encoded

def tokenize_full_function(example):
    """
    Tokenize sentence pairs for the full subset.

    Args:
        example: Mapping with "sentence1" and "sentence2" entries (a single
            example or a batch of them).

    Returns:
        The tokenizer output, padded and truncated to exactly 512 tokens.
    """
    first_sequences = example["sentence1"]
    second_sequences = example["sentence2"]
    # Per the original author's note, 512 tokens covers roughly 97% of the
    # full subset without truncation.
    encoded = tokenizer(
        first_sequences,
        second_sequences,
        max_length=512,
        padding="max_length",
        truncation=True,
    )
    return encoded

def plot_and_save_confusion_matrix(preds, labels, dataset_name="Protein Short"):
    """
    Print accuracy and a classification report, then save a confusion-matrix heatmap.

    Args:
        preds: Sequence (list or array) of predicted class ids, 0 or 1.
        labels: Sequence of reference class ids, same length as ``preds``.
        dataset_name: Label used in the printed messages and in the name of
            the saved PNG file.

    Returns:
        The accuracy shown in the figure title: the raw accuracy, or the
        accuracy after inverting the predictions when the raw value is < 0.5.

    Side effects:
        Prints to stdout and writes
        ``confusion_matrix_<dataset_name>_seed<seed>.png`` to the current
        working directory (``seed`` is the module-level seed).
    """
    # Callers pass plain Python lists; the inversion `1 - preds` below is only
    # defined element-wise on arrays, so coerce both inputs up front.
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    # Fixing the class ids keeps the report and the matrix 2x2 (and aligned
    # with `class_names`) even if one class is absent from both inputs.
    class_ids = [0, 1]
    class_names = ['Non-Homologous', 'Homologous']

    # 1. Raw accuracy.
    acc = accuracy_score(labels, preds)
    print(f"[{dataset_name}] Raw Accuracy: {acc:.4f}")

    # 2. Below-chance accuracy is treated as an inverted label mapping.
    # NOTE(review): this decision uses the reference labels, so the value
    # returned here can be more optimistic than metrics computed on the raw
    # predictions elsewhere. Behaviour kept from the original.
    if acc < 0.5:
        print(">>> Detected Label Inversion (Acc < 0.5). Rectifying...")
        preds = 1 - preds
        acc = accuracy_score(labels, preds)
        print(f"[{dataset_name}] Rectified Accuracy: {acc:.4f}")

    # 3. Detailed per-class report (class 0 / class 1 -> names above).
    print(f"\n>>> Classification Report for {dataset_name}:")
    report = classification_report(
        labels,
        preds,
        labels=class_ids,
        target_names=class_names,
        digits=4,
    )
    print(report)
    print("="*40)

    # 4. Confusion matrix on the (possibly rectified) predictions.
    cm = confusion_matrix(labels, preds, labels=class_ids)

    # 5. Heatmap. The figure is left open so a notebook can still render it.
    sns.set_theme(style="white", font_scale=1.2)
    fig, ax = plt.subplots(figsize=(6, 5))

    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        cbar=False,
        xticklabels=class_names,
        yticklabels=class_names,
        linewidths=1.5,
        linecolor='black',
        square=True,
        ax=ax,
    )

    ax.set_ylabel('True Label', fontsize=12, fontweight='bold')
    ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    ax.set_title(
        f'Confusion Matrix: Protein Homology Detection\nAccuracy: {acc:.2%}',
        fontsize=14,
        pad=15,
        fontweight='bold',
    )

    fig.tight_layout()

    filename = f"confusion_matrix_{dataset_name.replace(' ', '_')}_seed{seed}.png"
    fig.savefig(filename, dpi=300)
    print(f">>> Confusion Matrix saved to: {filename}")

    return acc

# 定义推理函数
def run_inference(test_dataset, batch_size=64, dataset_name=None):
    """
    Run batched inference over a tokenized split and score the predictions.

    Args:
        test_dataset: Tokenized dataset. Slicing it must yield a mapping with
            "input_ids", "attention_mask" and "label" columns, and all rows
            must have the same length because each batch is turned into a
            single tensor.
        batch_size: Number of examples per forward pass.
        dataset_name: Label forwarded to ``plot_and_save_confusion_matrix``
            (it ends up in the log lines and the PNG file name). When omitted,
            the dataset's ``config_name`` is used if it has one, otherwise
            the helper's historical default "Protein Short".

    Returns:
        The dictionary produced by the GLUE/MRPC metric for the raw
        (un-rectified) predictions.
    """
    if dataset_name is None:
        # Without a per-dataset name every call writes the same PNG, so the
        # second evaluation silently overwrites the first one's figure.
        # NOTE(review): relies on Hugging Face `Dataset.config_name` surviving
        # `train_test_split` / `map` — confirm on the installed version.
        dataset_name = getattr(test_dataset, "config_name", None) or "Protein Short"

    preds = []
    labels = []

    # disable=True keeps the progress bar out of the notebook output.
    for start in tqdm(range(0, len(test_dataset), batch_size), desc="Predicting", disable=True):
        batch = test_dataset[start : start + batch_size]

        input_ids = torch.tensor(batch["input_ids"], device=device)
        attention_mask = torch.tensor(batch["attention_mask"], device=device)

        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Store plain Python ints rather than numpy scalars.
        preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
        labels.extend(batch["label"])

    metric = evaluate.load("glue", "mrpc")
    # Hand arrays to the plotting helper: its label-inversion branch computes
    # `1 - preds`, which raises TypeError on a plain list.
    plot_and_save_confusion_matrix(np.asarray(preds), np.asarray(labels), dataset_name=dataset_name)
    return metric.compute(predictions=preds, references=labels)

In [None]:
# ==========================================
# Test set 1: protein_pair_short
# ==========================================
protein_short_train = load_dataset('dnagpt/biopaws', 'protein_pair_short')['train']
# Hold out 30% of the rows as the test split; the split is seeded.
raw_datasets_short = protein_short_train.train_test_split(test_size=0.3, seed=seed)

# Tokenize directly, with no other preprocessing.
tokenized_raw_datasets_short = raw_datasets_short.map(
    tokenize_short_function,
    batched=True,
    num_proc=4,
)

ret_1 = run_inference(tokenized_raw_datasets_short["test"])
result["protein_pair_short"] = ret_1


In [None]:
# ==========================================
# Test set 2: protein_pair_full
# ==========================================
protein_full_train = load_dataset('dnagpt/biopaws', 'protein_pair_full')['train']
# Hold out 30% of the rows as the test split; the split is seeded.
raw_datasets_full = protein_full_train.train_test_split(test_size=0.3, seed=seed)

# Tokenize directly (flip_labels was removed to stay consistent with the baseline script).
tokenized_raw_datasets_full = raw_datasets_full.map(
    tokenize_full_function,
    batched=True,
    num_proc=4,
)

ret_2 = run_inference(tokenized_raw_datasets_full["test"])
result["protein_pair_full"] = ret_2

In [None]:
# ==========================================
# Emit the collected results
# ==========================================
# Serialise first, then print, so the output is a single JSON line.
result_json = json.dumps(result)
print(result_json)