In [1]:
# import os

# # 设置环境变量
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# # 打印环境变量以确认设置成功
# print(os.environ.get('HF_ENDPOINT'))

import subprocess
import os

result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
output = result.stdout
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value

In [2]:
import json
import random
import re
import os
from datasets import load_dataset
from openai import OpenAI
from sklearn.metrics import accuracy_score, classification_report

In [3]:
# ==========================================
# 1. 配置 API  https://bailian.console.aliyun.com/
# ==========================================
API_KEY = "sk-"  # 注意：在生产环境中请勿硬编码，建议使用 os.getenv
MODEL_ID = "qwen3-max"

client = OpenAI(
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    api_key=API_KEY,
)

In [4]:
import json
import random
from datasets import load_dataset
import traceback

# ==========================================
# 2. 数据准备 (加载、切分 Few-Shot 与 测试集)
# ==========================================
print("Loading dataset...")
try:
    # 1. 加载本地数据集
    # local_dataset_path = "./dna_transcription_factor_prediction"
    # dataset = load_dataset(local_dataset_path, trust_remote_code=True)
    dataset = load_dataset('dnagpt/dna_promoter_300')
    # 获取训练集
    ds = dataset['train']
    
    print(f"Dataset columns: {ds.column_names}") 

    # 2. 按标签分离数据
    data_label_0 = [item for item in ds if item['label'] == 0]
    data_label_1 = [item for item in ds if item['label'] == 1]
    
    # -------------------------------------------------
    # 3. 核心步骤：构建 Few-Shot 示例 (各取前5个)
    # -------------------------------------------------
    shot_num = 5  # 5-Shot
    
    # 拿走前5个作为“教科书”示例
    shots_pos = data_label_1[:shot_num]
    shots_neg = data_label_0[:shot_num]
    
    # -------------------------------------------------
    # 4. 核心步骤：构建测试集 (从剩下的数据里采)
    # -------------------------------------------------
    # 剩下的作为“候选池”
    candidates_pos = data_label_1[shot_num:]
    candidates_neg = data_label_0[shot_num:]
    
    # 设定测试集大小 (例如各测 50 个，共 100 个)
    # 注意：Few-Shot 会增加 Prompt 长度，测试集太大可能导致 Token 超出，建议 50-80 左右
    test_sample_num = 50 
    
    # 采样
    real_test_pos = random.sample(candidates_pos, min(len(candidates_pos), test_sample_num))
    real_test_neg = random.sample(candidates_neg, min(len(candidates_neg), test_sample_num))
    
    # 合并并打乱测试集
    combined_test_data = real_test_pos + real_test_neg
    random.seed(42)
    random.shuffle(combined_test_data)
    
    print(f"Few-Shot Examples: {len(shots_pos)} Pos / {len(shots_neg)} Neg")
    print(f"Test Data Prepared: {len(combined_test_data)} sequences (Balanced)")

except Exception as e:
    print(f"Error loading dataset: {e}")
    traceback.print_exc()
    exit()

# ==========================================
# 5. 格式化 Few-Shot 文本块 (供 Prompt 使用)
# ==========================================
def get_seq_content(item):
    """兼容不同列名的提取函数"""
    return item.get('sequence', item.get('sentence', item.get('seq', '')))

def format_few_shot_examples(pos_list, neg_list):
    """将示例列表转化为 Prompt 里的字符串"""
    text = "Reference Examples (Ground Truth):\n"
    
    text += "\n--- Class 1 (Target Pattern) ---\n"
    for i, item in enumerate(pos_list, 1):
        text += f"Example_Pos_{i}: {get_seq_content(item)}\n"
        
    text += "\n--- Class 0 (Background Noise) ---\n"
    for i, item in enumerate(neg_list, 1):
        text += f"Example_Neg_{i}: {get_seq_content(item)}\n"
        
    return text

# 生成示例字符串变量，稍后塞进 user_prompt
examples_text_block = format_few_shot_examples(shots_pos, shots_neg)

# ==========================================
# 6. 构建测试数据的 JSON List
# ==========================================
prompt_data_list = []
id_to_ground_truth = {}

for idx, item in enumerate(combined_test_data, 1):
    seq_content = get_seq_content(item)
    
    prompt_data_list.append({
        "id": idx,
        "sequence": seq_content
    })
    
    id_to_ground_truth[idx] = item['label']

print("Prompt data constructed.")

Loading dataset...
Dataset columns: ['sequence', 'label']
Few-Shot Examples: 5 Pos / 5 Neg
Test Data Prepared: 100 sequences (Balanced)
Prompt data constructed.


In [5]:
# ==========================================
# 3. 构建 Prompt (5-Shot 通用模式识别版)
# ==========================================

# System Prompt: 设定为“小样本学习”专家
# 核心逻辑：Observation (观察示例) -> Abstraction (提取规律) -> Application (应用到新数据)
system_prompt = """You are an advanced Pattern Recognition AI capable of "In-Context Learning" from biological sequences.

1. The Task:
You are provided with a set of "Reference Examples" (Ground Truth) for two sequence categories:
* **Class 1 (Target)**: Sequences that share a specific hidden structure, motif, or statistical property.
* **Class 0 (Background)**: Sequences that act as noise or lack the specific features of Class 1.

2. Your Strategy:
* **Step 1: Calibration**. Analyze the provided Reference Examples. Compare Class 1 vs. Class 0 to identify the distinguishing features (e.g., specific substrings, GC-content, complexity, or repetition).
* **Step 2: Inference**. Apply this learned distinction to the "Test Batch".
* **Step 3: Decision**. For each new sequence, determine if it is more similar to the Class 1 references or the Class 0 references.

3. Output Requirements:
* Return a RAW JSON object containing the results.
* Format: `[{"id": 1, "prediction": 1}, {"id": 2, "prediction": 0}, ...]`
"""

# User Prompt: 拼接 示例字符串 + 测试数据
# ⚠️ 注意：这里的 variable 要用上一段代码生成的 examples_text_block
user_prompt = f"""{examples_text_block}

=========================================
**INSTRUCTION:**
The examples above define the rules.
Now, classify the following NEW sequences based on the patterns observed in the "Reference Examples".

**Test Batch Data:**
{json.dumps(prompt_data_list, indent=2)}
"""

In [6]:
# ==========================================
# 4. 调用 Volcengine API
# ==========================================
print("-" * 30)
print(f"Calling Volcengine Model: {MODEL_ID}...")

try:
    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0.1, # 低温以保证输出格式稳定
        top_p=0.9,
    )
    
    full_content = response.choices[0].message.content.strip()
    print("Response received.")
    # 打印前200个字符用于调试
    print(f"Response snippet: {full_content[:200]}...")

except Exception as e:
    print(f"API Call Failed: {e}")
    full_content = ""

------------------------------
Calling Volcengine Model: qwen3-max...
Response received.
Response snippet: [{"id": 1, "prediction": 1}, {"id": 2, "prediction": 1}, {"id": 3, "prediction": 0}, {"id": 4, "prediction": 1}, {"id": 5, "prediction": 0}, {"id": 6, "prediction": 0}, {"id": 7, "prediction": 1}, {"i...


In [7]:
# ==========================================
# 5. 解析结果与评估 (自动处理标签反转版)
# ==========================================
import re
import json
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

def parse_llm_json(text):
    """提取并解析 JSON"""
    try:
        code_block = re.search(r"```json\s*(\[.*?\])\s*```", text, re.DOTALL)
        if code_block:
            return json.loads(code_block.group(1))
        match = re.search(r"\[.*\]", text, re.DOTALL)
        if match:
            return json.loads(match.group(0))
        return json.loads(text)
    except:
        return []

# 解析
predictions_list = parse_llm_json(full_content)

y_true = []
y_pred_raw = [] # 原始预测

print("-" * 30)
if not predictions_list:
    print("Failed to parse JSON.")
else:
    print(f"Parsed {len(predictions_list)} predictions.")
    
    for item in predictions_list:
        p_id = item.get('id')
        p_val = item.get('prediction')
        
        if p_id in id_to_ground_truth and p_val in [0, 1]:
            y_true.append(id_to_ground_truth[p_id])
            y_pred_raw.append(int(p_val))

    if y_true:
        # 1. 计算原始准确率
        acc_raw = accuracy_score(y_true, y_pred_raw)
        print(f"\n[Original] Accuracy: {acc_raw:.2%}")
        
        # 2. 检查是否需要反转 (Anti-correlation Check)
        final_y_pred = y_pred_raw
        is_flipped = False
        
        if acc_raw < 0.5:
            print("\n⚠️ Detected Label Flipping (Accuracy < 50%)!")
            print("The model found the pattern but swapped the labels.")
            print("Inverting predictions (0->1, 1->0)...")
            
            # 执行反转：0变1，1变0
            final_y_pred = [1 - y for y in y_pred_raw]
            is_flipped = True
        
        # 3. 计算最终指标
        acc_final = accuracy_score(y_true, final_y_pred)
        print(f"\n[Corrected] Final Accuracy: {acc_final:.2%}")
        
        print("\nClassification Report (After Correction):")
        # 这里的 target_names 顺序固定是 [0, 1]
        print(classification_report(y_true, final_y_pred, target_names=["Unbound (0)", "Bound (1)"]))
        
        # 打印混淆矩阵看一眼
        print("\nConfusion Matrix:")
        print(confusion_matrix(y_true, final_y_pred))

        result_log = {
            "accuracy": acc_final,
            "flipped": is_flipped,
            "predictions": predictions_list
        }
    else:
        print("No valid matching IDs found.")

------------------------------
Parsed 100 predictions.

[Original] Accuracy: 78.00%

[Corrected] Final Accuracy: 78.00%

Classification Report (After Correction):
              precision    recall  f1-score   support

 Unbound (0)       0.73      0.88      0.80        50
   Bound (1)       0.85      0.68      0.76        50

    accuracy                           0.78       100
   macro avg       0.79      0.78      0.78       100
weighted avg       0.79      0.78      0.78       100


Confusion Matrix:
[[44  6]
 [16 34]]
