In [1]:
import os
import json
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# ================= 配置区域 =================
MODEL_ID = "google/gemma-2-9b-it"
OUTPUT_FILE = "ekar_gemma_extraction.jsonl"

# ================= 模型初始化 (针对 96GB 显存优化) =================
print(f"正在加载模型: {MODEL_ID} ...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)



model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",  # 如果是单张大卡用 "cuda"，如果是多卡且想自动分配用 "auto"
    torch_dtype=torch.bfloat16, 
)
print(">>> 模型已以 BF16 加载 ")


# ================= Prompt 构造函数 (保持不变) =================
def construct_prompt(item):
    question = item['question']
    choices = item['choices'] 
    answer_key = item['answerKey']
    
    # 处理解析
    raw_explanation = item.get('explanation', [])
    if isinstance(raw_explanation, list):
        explanation_str = "\n".join(raw_explanation)
    else:
        explanation_str = str(raw_explanation)

    # 提取选项
    choice_texts = choices['text']
    correct_idx = ord(answer_key) - ord('A')
    correct_choice_text = choice_texts[correct_idx] if 0 <= correct_idx < 4 else "未知"

    # Prompt 内容
    prompt_content = f"""以下是一个处理范例：
输入：
题目词对：生病:医药
选项：A.空虚：信仰 B.糊涂：明白 C.雷雨：大风 D.难过：高兴
正确答案：A
参考解析：
[
"改善“生病”的状况可以用“医药”。",
"改善“空虚”的状况可以借助“信仰”。",
"“糊涂”和“明白”的意思相反，“明白”不能改善“糊涂”的状况。",
"“雷雨”和“大风”都属于自然现象，“雷雨”的状况不可以由“大风”改善。",
"“难过”和“高兴”的意思相反，“难过”的状况不一定可以由“高兴”改善。"
]
输出：
{{
"question_logic": "医药是用来治疗生病状态的事物",
"answer_logic": "信仰是用来填补空虚状态的事物",
"abstract_relation": "<词2>是用来改善或解决<词1>所代表的负面状态的手段/事物"
}}

请分析以下数据并提取关系：

**题目词对**: {question}
**选项**:
A: {choice_texts[0]}
B: {choice_texts[1]}
C: {choice_texts[2]}
D: {choice_texts[3]}

**正确答案**: {answer_key} (即选项 {correct_choice_text})
**参考解析**: 
{explanation_str}

### 思考步骤

1. 分析【题目词对】的逻辑关系。
2. 分析【正确选项】的逻辑关系，寻找它与题目词对的共同点（这就是最佳颗粒度）。
3. 检查这个共同关系是否适用于错误选项。如果适用，说明颗粒度太粗，需要增加限定条件；如果不适用，则保留。

### 输出格式

请直接输出一个 JSON 对象，不要包含其他废话，格式如下：
{{
"question_logic": "简述题目词对的具体关系",
"answer_logic": "简述正确答案的具体关系",
"abstract_relation": "使用<词1>和<词2>生成的标准化关系模板"
}}"""
    return prompt_content

# ================= 推理函数 =================
def generate_response(prompt_text):
    messages = [{"role": "user", "content": prompt_text}]
    
    input_ids = tokenizer.apply_chat_template(
        messages, 
        return_tensors="pt", 
        add_generation_prompt=True
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>")
    ]

    # Gemma-2 在 BF16 下性能极佳
    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=False, # 贪婪解码，保证逻辑提取的一致性
    )
    
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response

def process_dataset():
    print("正在加载数据集...")
    try:
        dataset = load_dataset("jiangjiechen/ekar_chinese")
    except Exception as e:
        print(f"数据集加载失败: {e}")
        return

    all_data = []
    for split in ['train', 'validation']:
        for item in dataset[split]:
            all_data.append(item)
    
    print(f"总数据量: {len(all_data)} 条")

    # 断点续传
    processed_ids = set()
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    record = json.loads(line)
                    if 'id' in record: 
                        processed_ids.add(record['id'])
                except: 
                    pass
    print(f"跳过已处理: {len(processed_ids)} 条")

    # 主循环
    with open(OUTPUT_FILE, 'a', encoding='utf-8') as f_out:
        for item in tqdm(all_data, desc="Gemma 96GB High-Res Extract"):
            item_id = item['id']
            if item_id in processed_ids:
                continue

            prompt = construct_prompt(item)
            
            try:
                response_text = generate_response(prompt)
                
                # 清洗 JSON
                clean_json_str = response_text.replace("```json", "").replace("```", "").strip()
                try:
                    model_output = json.loads(clean_json_str)
                except json.JSONDecodeError:
                    model_output = {"error": "JSON_DECODE_FAIL", "raw_content": response_text}

                final_record = {
                    "id": item_id,
                    "original_question": item['question'],
                    "model_extraction": model_output
                }
                
                f_out.write(json.dumps(final_record, ensure_ascii=False) + "\n")
                f_out.flush()

            except Exception as e:
                print(f"\n[Fail] ID {item_id}: {e}")
                continue

if __name__ == "__main__":
    process_dataset()

正在加载模型: google/gemma-2-9b-it ...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:44<00:00, 11.07s/it]


>>> 模型已以 BF16 加载 
正在加载数据集...
总数据量: 1320 条
跳过已处理: 0 条


Gemma 96GB High-Res Extract: 100%|██████████| 1320/1320 [1:21:30<00:00,  3.71s/it]


In [3]:
INPUT_FILE = "ekar_gemma_extraction.jsonl"
OUTPUT_FILE = "relations_for_analysis.txt"

def export_relations():
    count = 0
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f_out:
        with open(INPUT_FILE, 'r', encoding='utf-8') as f_in:
            for line in f_in:
                try:
                    record = json.loads(line)
                    # 获取模型提取的抽象关系
                    # 注意：要兼容一下你之前代码里可能的字段结构
                    if 'model_extraction' in record and 'abstract_relation' in record['model_extraction']:
                        relation = record['model_extraction']['abstract_relation']
                        question = record.get('original_question', '未知题目')
                        
                        # 格式：[行号] (题目) 关系描述
                        f_out.write(f"[{count+1}] ({question}) {relation}\n")
                        count += 1
                except Exception as e:
                    continue
    
    print(f"✅ 已成功导出 {count} 条关系描述到 {OUTPUT_FILE}")
    print("现在你可以把这个文件直接传给 Google AI Studio 了。")

if __name__ == "__main__":
    export_relations()

✅ 已成功导出 1319 条关系描述到 relations_for_analysis.txt
现在你可以把这个文件直接传给 Google AI Studio 了。


In [None]:
final_relation_tags = [
    # --- 基础语义逻辑 ---
    "近义/同一",      # Synonyms/Identity (马铃薯:土豆)
    "反义/对立",      # Antonyms (高:矮) - 注意：一般的反义词放这里
    "包含/种属",      # Hyponymy (苹果:水果)
    "组成/整体",      # Meronymy (轮胎:汽车)
    "并列/同类",      # Coordinate (长江:黄河)
    "属性/特征",      # Attribute (糖:甜)
    "象征/比喻",      # Symbolism (鸽子:和平)

    # --- 实体交互逻辑 (物理/功能) ---
    "主体-动作",      # Agent-Action (鸟:飞)
    "主体-产物",      # Agent-Product (作家:书)
    "材料-成品",      # Material-Product (木头:桌子)
    "工具-功能",      # Tool-Function (刀:切)
    "位置/空间",      # Location (轮船:海)
    "顺序/过程",      # Sequence (报名:考试)
    "因果/依赖",      # Causality (缺水:干旱)

    # --- 社会/人物逻辑 (从"社会关系"拆分) ---
    "亲属关系",       # Kinship (舅舅:外甥) - 只有血缘/婚姻
    "师生传承",       # Mentorship (孔子:颜回) - 强调教导
    "职业-对象",      # Professional-Client (医生:病人) - 强调服务
    "等级/排序",      # Hierarchy (经理:职员) - 强调地位差
]

In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import csv
import re
from tqdm import tqdm

# ================= 配置区域 =================
MODEL_ID = "google/gemma-2-9b-it"
INPUT_FILE = 'relations_for_analysis.txt'
OUTPUT_FILE = 'labeled_relations.csv'
# --- 新增：批处理大小 ---
BATCH_SIZE = 32

# ================= 模型初始化 =================
print(f"正在加载模型: {MODEL_ID} ...")

# --- 修改点：为 Tokenizer 添加 padding_side 以支持批处理 ---
# 对于解码器-only 的模型 (如 Gemma), padding 需要在左侧
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side='left')
# 如果分词器没有默认的 pad token, 手动设置为 eos_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
print(">>> 模型已以 BF16 加载完成")

# ================= 功能函数 =================

def parse_line(line):
    """
    解析原始文本行。
    格式: [1] (军队:命令) <词2>是用来控制或引导<词1>的行为
    """
    pattern = r"^\[(\d+)\] \((.*?)\) (.*)$"
    match = re.match(pattern, line.strip())

    if match:
        return {
            "id": match.group(1),
            "terms": match.group(2),
            "description": match.group(3)
        }
    return None


# --- 重大修改：函数重构为批量处理 ---
def get_labels_from_gemma_batch(batch_items):
    """
    构造 Prompt 并获取模型对一个批次的分类结果 (中文 Prompt 版本)
    """
    tags_str = "\n".join([f"- {tag}" for tag in final_relation_tags])
    all_messages = []

    # 为批次中的每个项目创建 Prompt
    for item in batch_items:
        user_content = f"""你是一位严格的知识图谱数据分类专家。
你的任务是根据提供的描述，将词项之间的语义关系归类到预定义的候选列表中。

### 候选分类列表：
{tags_str}

### 待分类数据：
- 词项：({item['terms']})
- 逻辑描述：{item['description']}

### 指令：
请分析“逻辑描述”所表达的关系。从“候选分类列表”中选择最恰当的一个类别。
请严格仅输出列表中的类别名称。不要输出任何其他文本、标点或解释。
"""
        messages = [{"role": "user", "content": user_content}]
        all_messages.append(messages)

    # 使用 Gemma 的标准对话模板，并进行 padding
    # tokenizer 可以一次性处理一个批次的对话列表
    input_ids = tokenizer.apply_chat_template(
        all_messages,
        return_tensors="pt",
        add_generation_prompt=True,
        padding=True, # 启用填充
        truncation=True # 启用截断
    ).to(model.device)

    # 生成配置: 贪婪解码(do_sample=False)以获得最稳定的分类
    outputs = model.generate(
        input_ids,
        max_new_tokens=32,  # 标签很短，不需要生成很长
        do_sample=False,    # 关闭随机采样
        temperature=None,   # 贪婪模式不需要 temp
        top_p=None
    )

    # 解码 (只取生成的新 token)
    # --- 修改点：使用 batch_decode ---
    input_ids_len = input_ids.shape[1]
    generated_tokens = outputs[:, input_ids_len:]
    responses = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    # 返回一个包含所有原始输出的列表
    return [res.strip() for res in responses]


def clean_label(raw_output):
    """
    清洗模型输出，去除可能的标点或多余空格
    """
    # 移除可能出现的 "类别：" 或 "Category:" 前缀
    clean_text = raw_output.replace("Category:", "").replace("类别：", "").strip()

    # 移除可能的末尾句号
    if clean_text.endswith("。") or clean_text.endswith("."):
        clean_text = clean_text[:-1]

    # 验证是否在列表中
    if clean_text in final_relation_tags:
        return clean_text

    # 模糊匹配修复
    for tag in final_relation_tags:
        if tag in clean_text:
            return tag

    return "Uncategorized"

# ================= 主程序 =================

def main():
    # 1. 读取数据
    data_items = []
    try:
        with open(INPUT_FILE, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            for line in lines:
                if line.strip():
                    parsed = parse_line(line)
                    if parsed:
                        data_items.append(parsed)
        print(f"成功读取 {len(data_items)} 条数据。")
    except FileNotFoundError:
        print(f"错误: 找不到文件 {INPUT_FILE}")
        return

    # 2. 处理并写入
    print(f"开始推理... (Batch Size: {BATCH_SIZE})")

    # 使用 'w' 模式写入 CSV
    with open(OUTPUT_FILE, 'w', newline='', encoding='utf-8-sig') as csvfile:
        fieldnames = ['id', 'terms', 'description', 'label', 'raw_output']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        # --- 修改点：按批次循环处理 ---
        # 使用 tqdm 显示批次的进度
        for i in tqdm(range(0, len(data_items), BATCH_SIZE), desc="Processing Batches"):
            # 获取当前批次的数据
            batch = data_items[i:i + BATCH_SIZE]
            
            try:
                # 批量获取模型的原始输出
                raw_outputs = get_labels_from_gemma_batch(batch)

                # 遍历当前批次的结果并写入文件
                for item, raw_output in zip(batch, raw_outputs):
                    # 清洗标签
                    label = clean_label(raw_output)

                    # 写入
                    writer.writerow({
                        'id': item['id'],
                        'terms': item['terms'],
                        'description': item['description'],
                        'label': label,
                        'raw_output': raw_output
                    })
            except Exception as e:
                print(f"处理批次 {i//BATCH_SIZE + 1} 时出错 (ID 从 {batch[0]['id']} 开始): {e}")
                # 你可以选择在这里为批次中的所有项目写入错误信息，或者跳过
                for item in batch:
                     writer.writerow({
                        'id': item['id'],
                        'terms': item['terms'],
                        'description': item['description'],
                        'label': 'Error',
                        'raw_output': str(e)
                    })
                continue

    print(f"\n任务完成！结果已保存至: {OUTPUT_FILE}")

if __name__ == "__main__":
    main()

正在加载模型: google/gemma-2-9b-it ...


Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.88s/it]


>>> 模型已以 BF16 加载完成
成功读取 1319 条数据。
开始推理... (Batch Size: 32)


Processing Batches:   0%|          | 0/42 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Processing Batches: 100%|██████████| 42/42 [00:46<00:00,  1.10s/it]


任务完成！结果已保存至: labeled_relations.csv





In [2]:
import json
import os
from dotenv import load_dotenv
from datasets import load_dataset

def main():
    # 1. 加载环境变量
    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")
    
    if not hf_token:
        print("警告: 未在 .env 文件或环境变量中找到 HF_TOKEN。")
        print("尝试以匿名模式下载（如果数据集是公开的通常没问题）...")
    else:
        print("已检测到 HF_TOKEN，正在使用身份验证...")

    # 数据集 ID
    DATASET_ID = "jiangjiechen/ekar_chinese"
    OUTPUT_FILE = "ekar_test.json"

    print(f"正在从 Hugging Face 下载 {DATASET_ID} 的 test 集...")

    try:
        # 2. 加载数据集
        # split="test" 指定只下载测试集
        # token=hf_token 用于身份验证
        dataset = load_dataset(
            DATASET_ID, 
            split="test", 
            token=hf_token,
            trust_remote_code=True # 允许执行数据集仓库中的加载脚本
        )
    except Exception as e:
        print(f"下载失败: {e}")
        print("请检查你的 HF_TOKEN 是否正确，以及是否有权限访问该数据集。")
        return

    print(f"下载成功！共获取 {len(dataset)} 条数据。")

    # 3. 转换为 Python 列表并保存为 JSON
    # 这步是为了配合你之前的脚本，将其保存为标准的 JSON 数组格式
    print(f"正在转换并保存到 {OUTPUT_FILE} ...")
    
    data_list = []
    # 遍历 dataset 对象将其转为普通 dict
    for item in dataset:
        data_list.append(item)

    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(data_list, f, ensure_ascii=False, indent=2)

    print("完成！文件已保存，可以运行提取脚本了。")

if __name__ == "__main__":
    main()

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'jiangjiechen/ekar_chinese' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


已检测到 HF_TOKEN，正在使用身份验证...
正在从 Hugging Face 下载 jiangjiechen/ekar_chinese 的 test 集...
下载成功！共获取 335 条数据。
正在转换并保存到 ekar_test.json ...
完成！文件已保存，可以运行提取脚本了。


In [4]:
import json
import os

# ================= 配置区域 =================
RESULT_FILE = 'analysis_consistency_result.json'  # 分析结果文件
LABEL_MAP_FILE = 'label_map.json'                 # 标签映射文件
EKAR_FILE = 'ekar_test.json'                      # 原始数据文件（包含选项文本）
OUTPUT_FILE = 'output_word_pairs.txt'             # 输出文件
# ===========================================

def load_json(filename):
    """读取JSON文件，带错误处理"""
    if not os.path.exists(filename):
        print(f"错误: 找不到文件 {filename}")
        return None
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        print(f"错误: 读取 {filename} 失败。详情: {e}")
        return None

def split_pair(text):
    """
    将 '词1:词2' 格式的字符串分割为元组。
    支持中文冒号和英文冒号。
    """
    if not isinstance(text, str):
        return ("Unknown", "Unknown")
    
    # 统一替换中文冒号
    text = text.replace("：", ":")
    
    parts = text.split(":")
    if len(parts) >= 2:
        return (parts[0].strip(), parts[1].strip())
    else:
        # 如果无法分割，保留原文本作为词1，词2为空
        return (text.strip(), "")

def build_ekar_lookup(ekar_data):
    """
    将ekar列表转换为以id为key的字典，方便查询
    返回结构:
    {
        "id": {
            "q_words": ("生病", "医药"),
            "opt_words_list": [("词A", "词B"), ("词C", "词D")...]
        }
    }
    """
    lookup = {}
    # 兼容 ekar_data 可能是 list 或者 dict (lines) 的情况
    if isinstance(ekar_data, dict):
        # 有些数据集是一行一个json对象的lines格式，这里假设是标准list
        pass 
    
    data_list = ekar_data if isinstance(ekar_data, list) else []
    
    for item in data_list:
        data_id = item.get("id")
        if not data_id:
            continue
            
        # 1. 提取题干词对
        # 假设字段可能是 'question' 或 'question_text'
        q_text = item.get("question", item.get("question_text", ""))
        q_pair = split_pair(q_text)
        
        # 2. 提取选项词对
        # 假设字段是 'choices' 或 'options'
        choices = item.get("choices", item.get("options", []))
        # 如果是字典形式 {"A": "...", "B": "..."} 转为列表排序
        if isinstance(choices, dict):
            sorted_keys = sorted(choices.keys())
            choices_list = [choices[k] for k in sorted_keys]
        else:
            choices_list = choices
            
        opt_pairs = [split_pair(c) for c in choices_list]
        
        lookup[data_id] = {
            "q_words": q_pair,
            "opt_words_list": opt_pairs
        }
    return lookup

def main():
    # 1. 加载 Label Map
    print("正在加载标签映射...")
    label_map_raw = load_json(LABEL_MAP_FILE)
    if not label_map_raw: return
    # 建立 反向映射: "主体-动作" -> 2
    name_to_id = {v: int(k) for k, v in label_map_raw.items()}

    # 2. 加载 EKAR 原始数据 (为了获取词语)
    print("正在加载 EKAR 原始数据...")
    ekar_data = load_json(EKAR_FILE)
    if not ekar_data: 
        print("请确认 ekar_test.json 文件是否存在且格式正确。")
        return
    
    # 构建查询字典
    ekar_lookup = build_ekar_lookup(ekar_data)
    print(f"已加载 {len(ekar_lookup)} 条原始题目数据。")

    # 3. 加载分析结果 (为了获取预测标签)
    print("正在加载分析结果...")
    analysis_data = load_json(RESULT_FILE)
    if not analysis_data: return
    
    if isinstance(analysis_data, dict):
        analysis_data = [analysis_data]

    results_to_write = []
    
    # 4. 核心处理循环
    print("正在匹配数据...")
    for entry in analysis_data:
        uid = entry.get("id")
        
        # 查找对应的词语数据
        if uid not in ekar_lookup:
            print(f"跳过: ID {uid} 在 {EKAR_FILE} 中未找到。")
            continue
            
        word_data = ekar_lookup[uid]
        classifier_info = entry.get("classifier_analysis", {})
        
        # --- 处理题干 (Question) ---
        q_label_name = classifier_info.get("q_label")
        q_grid_index = name_to_id.get(q_label_name, 0) # 找不到默认为0 (Uncategorized)
        
        q_word_a, q_word_b = word_data["q_words"]
        
        # 添加题干数据: (词A, 词B, 标签ID)
        results_to_write.append((q_word_a, q_word_b, q_grid_index))
        
        # --- 处理选项 (Options) ---
        opt_label_names = classifier_info.get("opt_labels", [])
        opt_word_pairs = word_data["opt_words_list"]
        
        # 确保选项数量对齐（通常是4个）
        for idx, label_name in enumerate(opt_label_names):
            if idx < len(opt_word_pairs):
                # 获取该选项对应的词对
                o_word_a, o_word_b = opt_word_pairs[idx]
                
                # 获取该选项对应的标签ID
                o_grid_index = name_to_id.get(label_name, 0)
                
                # 添加选项数据
                results_to_write.append((o_word_a, o_word_b, o_grid_index))

    # 5. 写入文件
    print(f"正在写入 {len(results_to_write)} 条数据到 {OUTPUT_FILE} ...")
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        for item in results_to_write:
            # item 已经是 (Obj_A, Obj_B, grid_index) 格式
            f.write(f"{str(item)}\n")
            
    print("处理完成！")

if __name__ == "__main__":
    main()

正在加载标签映射...
正在加载 EKAR 原始数据...
已加载 335 条原始题目数据。
正在加载分析结果...
正在匹配数据...
正在写入 1005 条数据到 output_word_pairs.txt ...
处理完成！


In [5]:
import os
import json
import csv

# 根目录（按你的路径）
BASE_DIR = r"d:\VH_analogy\20260201_analogyreason"

analysis_path = os.path.join(BASE_DIR, "analysis_consistency_result.json")
label_map_path = os.path.join(BASE_DIR, "label_map.json")
ekar_path = os.path.join(BASE_DIR, "ekar_test.json")
output_path = os.path.join(BASE_DIR, "obj_grid_index.csv")


def load_json(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def build_label_index(label_map):
    """
    label_map: 形如 {"0": "Uncategorized", "1": "主体-产物", ...}
    返回: {"主体-产物": 1, ...}，自动 strip() 掉前后空格
    """
    mapping = {}
    for k, v in label_map.items():
        if not isinstance(v, str):
            continue
        mapping[v.strip()] = int(k)
    return mapping


def extract_pair(text):
    """
    从形如 "校长:老师" 或 "青年:村官:基层锻炼" 的字符串中
    提取前两个词，作为 (Obj_A, Obj_B)。

    如果不能提取到两个词，返回 (None, None)。
    """
    if not isinstance(text, str):
        return None, None
    parts = [p.strip() for p in text.split(":") if p.strip()]
    if len(parts) < 2:
        return None, None
    return parts[0], parts[1]


def main():
    # 读取三个 JSON
    analysis_data = load_json(analysis_path)
    label_map = load_json(label_map_path)
    ekar_data = load_json(ekar_path)

    # 构建 label -> index 映射
    label_to_index = build_label_index(label_map)

    # 把 ekar_test 按 id 建索引
    ekar_by_id = {item.get("id"): item for item in ekar_data if item.get("id")}

    rows = []

    # 统计信息
    total_items = len(analysis_data)
    missing_ekar_count = 0
    missing_q_label_index = 0
    missing_opt_label_index = 0
    bad_question_text_count = 0
    bad_option_text_count = 0

    for item in analysis_data:
        qid = item.get("id")
        if not qid:
            continue

        ekar_item = ekar_by_id.get(qid)
        if ekar_item is None:
            missing_ekar_count += 1

        classifier = item.get("classifier_analysis", {}) or {}
        q_label_raw = classifier.get("q_label")
        opt_labels_raw = classifier.get("opt_labels", []) or []

        # 规范化 label
        q_label = q_label_raw.strip() if isinstance(q_label_raw, str) else None
        opt_labels = [
            (l.strip() if isinstance(l, str) else None) for l in opt_labels_raw
        ]

        # 1）题干：优先用 ekar_test.json 的 question；如果 ekar 里没有，就用 analysis 的 question_text
        if ekar_item is not None:
            q_text = ekar_item.get("question")
        else:
            q_text = item.get("question_text")

        obj_a, obj_b = extract_pair(q_text)

        if obj_a and obj_b and q_label in label_to_index:
            grid_index = label_to_index[q_label]
            rows.append((obj_a, obj_b, grid_index))
        else:
            # 统计问题：label 找不到 or 文本不是合法词对
            if not (obj_a and obj_b):
                bad_question_text_count += 1
            if q_label and q_label not in label_to_index:
                missing_q_label_index += 1

        # 2）选项：只能在 ekar_test 中找选项文本；如果 ekar 里没这个 id，选项就没法生成
        if ekar_item is None:
            continue

        choices = ekar_item.get("choices", {}) or {}
        choice_texts = choices.get("text", []) or []

        for i, opt_label in enumerate(opt_labels):
            if i >= len(choice_texts):
                # 防止 opt_labels 和 choices.text 长度不一致
                continue

            text = choice_texts[i]
            obj_a, obj_b = extract_pair(text)
            if not (obj_a and obj_b):
                bad_option_text_count += 1
                continue

            if not opt_label or opt_label not in label_to_index:
                missing_opt_label_index += 1
                continue

            grid_index = label_to_index[opt_label]
            rows.append((obj_a, obj_b, grid_index))

    # 写出 CSV
    with open(output_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Obj_A", "Obj_B", "grid_index"])
        writer.writerows(rows)

    # 打印统计
    print(f"analysis_consistency_result 题目总数: {total_items}")
    print(f"成功写出的 (Obj_A, Obj_B, grid_index) 行数: {len(rows)}")
    print(f"在 ekar_test.json 中找不到对应 id 的题目数: {missing_ekar_count}")
    print(f"题干 label 在 label_map 中找不到的数量: {missing_q_label_index}")
    print(f"选项 label 在 label_map 中找不到的数量: {missing_opt_label_index}")
    print(f"题干文本无法抽出两词对 (Obj_A, Obj_B) 的数量: {bad_question_text_count}")
    print(f"选项文本无法抽出两词对 (Obj_A, Obj_B) 的数量: {bad_option_text_count}")
    print(f"写出文件: {output_path}")


if __name__ == "__main__":
    main()

analysis_consistency_result 题目总数: 335
成功写出的 (Obj_A, Obj_B, grid_index) 行数: 1675
在 ekar_test.json 中找不到对应 id 的题目数: 0
题干 label 在 label_map 中找不到的数量: 0
选项 label 在 label_map 中找不到的数量: 0
题干文本无法抽出两词对 (Obj_A, Obj_B) 的数量: 0
选项文本无法抽出两词对 (Obj_A, Obj_B) 的数量: 0
写出文件: d:\VH_analogy\20260201_analogyreason\obj_grid_index.csv


In [2]:
from dotenv import load_dotenv
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

In [5]:
# ================= 配置区域 =================
MODEL_ID = "google/gemma-2-9b-it"
OUTPUT_FILE = "RelationBook_train_valid.jsonl"

# 本地数据文件路径
TRAIN_FILE = r"D:\VH_analogy\ekar_train_cleaned.json"
VALID_FILE = r"D:\VH_analogy\ekar_valid_cleaned.json"

# --- 新增：批处理大小 ---
BATCH_SIZE = 64

os.environ['HF_HUB_ENDPOINT'] = os.getenv("HF_ENDPOINT")
# ================= 模型初始化 (针对 96GB 显存优化) =================
print(f"正在加载模型: {MODEL_ID} ...")

# --- 修改点：为 Tokenizer 添加 padding_side 以支持批处理 ---
# 对于解码器-only 的模型 (如 Gemma), padding 需要在左侧
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side='left')
# 如果分词器没有默认的 pad token, 手动设置为 eos_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",  # 如果是单张大卡用 "cuda"，如果是多卡且想自动分配用 "auto"
    torch_dtype=torch.bfloat16, 
    token=HF_TOKEN
)
print(">>> 模型已以 BF16 加载 ")


# ================= Prompt 构造函数 (保持不变) =================
def construct_prompt(item):
    question = item['question']
    choices = item['choices'] 
    answer_key = item['answerKey']
    
    # 处理解析
    raw_explanation = item.get('explanation', [])
    if isinstance(raw_explanation, list):
        explanation_str = "\n".join(raw_explanation)
    else:
        explanation_str = str(raw_explanation)

    # 提取选项
    choice_texts = choices['text']
    correct_idx = ord(answer_key) - ord('A')
    correct_choice_text = choice_texts[correct_idx] if 0 <= correct_idx < 4 else "未知"

    # Prompt 内容
    prompt_content = f"""以下是一个处理范例：
输入：
题目词对：生病:医药
选项：A.空虚：信仰 B.糊涂：明白 C.雷雨：大风 D.难过：高兴
正确答案：A
参考解析：
[
"改善"生病"的状况可以用"医药"。",
"改善"空虚"的状况可以借助"信仰"。",
""糊涂"和"明白"的意思相反，"明白"不能改善"糊涂"的状况。",
""雷雨"和"大风"都属于自然现象，"雷雨"的状况不可以由"大风"改善。",
""难过"和"高兴"的意思相反，"难过"的状况不一定可以由"高兴"改善。"
]
输出：注意这里的关系描述都是<词1>相对<词2>的关系，请严格按照这个格式输出。
{{
"question_relation": "可以被治疗或改善",
"relation_A": "可以被治疗或改善",
"relation_B": "相反的意思",
"relation_C": "并列伴随关系",
"relation_D": "相反的意思"
}}

请分析以下数据并提取关系：注意提取的关系描述都是<词1>相对<词2>的关系

**题目词对**: {question}
**选项**:
A: {choice_texts[0]}
B: {choice_texts[1]}
C: {choice_texts[2]}
D: {choice_texts[3]}

**正确答案**: {answer_key} (即选项 {correct_choice_text})
**参考解析**: 
{explanation_str}

### 思考步骤
思考的步骤如下：
1. 分析【题目词对】的逻辑关系。
2. 分析【正确选项】的逻辑关系，寻找它与题目词对的共同点（这就是最佳颗粒度）。
3. 检查这个共同关系是否适用于错误选项。如果适用，说明颗粒度太粗，需要增加限定条件；如果不适用，则保留。
4. 对每个错误选项，分析其逻辑关系
5. 确保提取的关系是从<词1>相对<词2>的视角出发的。
### 输出格式

请直接输出一个 JSON 对象，不要包含其他废话，格式如下：
{{
"question_relation": "题目的词对之间的关系",
"relation_A": "选项A的词对之间的关系",
"relation_B": "选项B的词对之间的关系",
"relation_C": "选项C的词对之间的关系",
"relation_D": "选项D的词对之间的关系"
}}"""
    return prompt_content

# ================= 推理函数 =================
def get_extractions_from_gemma_batch(batch_items):
    """
    构造 Prompt 并获取模型对一个批次的提取结果
    """
    all_messages = []

    # 为批次中的每个项目创建 Prompt
    for item in batch_items:
        prompt_text = construct_prompt(item)
        messages = [{"role": "user", "content": prompt_text}]
        all_messages.append(messages)

    # 使用 Gemma 的标准对话模板，并进行 padding
    input_ids = tokenizer.apply_chat_template(
        all_messages,
        return_tensors="pt",
        add_generation_prompt=True,
        padding=True,  # 启用填充
        truncation=True  # 启用截断
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>")
    ]

    # 生成配置: 贪婪解码(do_sample=False)以获得最稳定的提取
    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=False,  # 贪婪解码，保证逻辑提取的一致性
    )

    # 解码 (只取生成的新 token)
    input_ids_len = input_ids.shape[1]
    generated_tokens = outputs[:, input_ids_len:]
    responses = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    # 返回一个包含所有原始输出的列表
    return [res.strip() for res in responses]

def process_dataset():
    """从本地JSON文件加载数据"""
    print("正在加载本地数据集...")
    
    all_data = []
    
    # 加载训练集
    try:
        with open(TRAIN_FILE, 'r', encoding='utf-8') as f:
            train_data = json.load(f)
            all_data.extend(train_data)
        print(f"✓ 加载训练集: {len(train_data)} 条数据")
    except FileNotFoundError:
        print(f"✗ 找不到文件: {TRAIN_FILE}")
        return
    except Exception as e:
        print(f"✗ 加载训练集失败: {e}")
        return
    
    # 加载验证集
    try:
        with open(VALID_FILE, 'r', encoding='utf-8') as f:
            valid_data = json.load(f)
            all_data.extend(valid_data)
        print(f"✓ 加载验证集: {len(valid_data)} 条数据")
    except FileNotFoundError:
        print(f"✗ 找不到文件: {VALID_FILE}")
        return
    except Exception as e:
        print(f"✗ 加载验证集失败: {e}")
        return
    
    print(f"总数据量: {len(all_data)} 条")

    # 断点续传
    processed_ids = set()
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    record = json.loads(line)
                    if 'id' in record: 
                        processed_ids.add(record['id'])
                except: 
                    pass
    print(f"跳过已处理: {len(processed_ids)} 条")

    # 主循环 - 按批次处理
    with open(OUTPUT_FILE, 'a', encoding='utf-8') as f_out:
        for i in tqdm(range(0, len(all_data), BATCH_SIZE), desc="Gemma 本地数据处理"):
            # 获取当前批次的数据
            batch = all_data[i:i + BATCH_SIZE]
            
            # 过滤已处理的
            batch_to_process = [item for item in batch if item['id'] not in processed_ids]
            
            if not batch_to_process:
                continue
            
            try:
                # 批量获取模型的原始输出
                responses = get_extractions_from_gemma_batch(batch_to_process)

                # 遍历当前批次的结果并写入文件
                for item, response_text in zip(batch_to_process, responses):
                    item_id = item['id']
                    
                    # 清洗 JSON
                    clean_json_str = response_text.replace("```json", "").replace("```", "").strip()
                    try:
                        model_output = json.loads(clean_json_str)
                    except json.JSONDecodeError:
                        model_output = {"error": "JSON_DECODE_FAIL", "raw_content": response_text}

                    final_record = {
                        "id": item_id,
                        "original_question": item['question'],
                        "model_extraction": model_output
                    }
                    
                    f_out.write(json.dumps(final_record, ensure_ascii=False) + "\n")
                    f_out.flush()

            except Exception as e:
                print(f"处理批次 {i//BATCH_SIZE + 1} 时出错 (ID 从 {batch[0]['id']} 开始): {e}")
                # 你可以选择在这里为批次中的所有项目写入错误信息，或者跳过
                for item in batch_to_process:
                    item_id = item['id']
                    final_record = {
                        "id": item_id,
                        "original_question": item['question'],
                        "model_extraction": {"error": str(e)}
                    }
                    f_out.write(json.dumps(final_record, ensure_ascii=False) + "\n")
                    f_out.flush()
                continue

if __name__ == "__main__":
    process_dataset()

正在加载模型: google/gemma-2-9b-it ...


Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.91s/it]


>>> 模型已以 BF16 加载 
正在加载本地数据集...
✓ 加载训练集: 750 条数据
✓ 加载验证集: 113 条数据
总数据量: 863 条
跳过已处理: 15 条


Gemma 本地数据处理:   0%|          | 0/14 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Gemma 本地数据处理: 100%|██████████| 14/14 [04:17<00:00, 18.37s/it]


In [7]:
import json

# ================= 配置区域 =================
INPUT_FILE = "RelationBook_train_valid.jsonl"
OUTPUT_FILE = "relations_output.txt"

def extract_relations():
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f_out:
        with open(INPUT_FILE, 'r', encoding='utf-8') as f_in:
            for line in f_in:
                try:
                    record = json.loads(line)
                    model_extraction = record.get('model_extraction', {})
                    question_relation = model_extraction.get('question_relation', '')
                    relation_A = model_extraction.get('relation_A', '')
                    relation_B = model_extraction.get('relation_B', '')
                    relation_C = model_extraction.get('relation_C', '')
                    relation_D = model_extraction.get('relation_D', '')
                    
                    # 写入五个关系，每行一个（纵列排列）
                    f_out.write(f"{question_relation}\n{relation_A}\n{relation_B}\n{relation_C}\n{relation_D}\n")
                except Exception as e:
                    print(f"Error processing line: {e}")
                    continue
    
    print(f"✅ 已成功提取关系到 {OUTPUT_FILE}")

if __name__ == "__main__":
    extract_relations()

✅ 已成功提取关系到 relations_output.txt


In [13]:
Relation_Labels = [
    "同义关系",
    "近义关系",
    "反义关系",
    "包含关系（整体-部分）",
    "部分-整体关系（组成部分）",
    "种属关系（词1是词2的一种）",
    "类别关系（词1和词2是同类）",
    "并列关系",
    "对立矛盾关系",
    "互补关系",
    "可能因果关系",
    "必然因果关系",
    "必要条件关系",
    "充分条件关系",
    "目的关系",
    "结果关系",
    "递进关系",
    "工具-用途关系",
    "工具-使用者关系",
    "工具-对象关系",
    "场所-活动关系",
    "场所-人员关系",
    "场所-物品关系",
    "相邻位置关系",
    "位于关系",
    "时间顺序关系",
    "时间同时关系",
    "时间间隔关系",
    "时期朝代关系",
    "节气关系",
    "省份关系",
    "城市关系",
    "国家关系",
    "亲属关系",
    "师生关系",
    "上下级关系",
    "同事关系",
    "服务关系",
    "敌对关系",
    "职业-工作内容关系",
    "职业-工具关系",
    "作者-作品关系",
    "作品-人物关系",
    "作品-思想关系",
    "象征关系",
    "比喻关系",
    "标志关系",
    "产地关系",
    "材料-成品关系",
    "功能关系",
    "属性特征关系",
    "数量关系",
    "单位-物理量关系",
    "语法关系-主谓",
    "语法关系-动宾",
    "语法关系-偏正",
    "成语结构关系",
    "音译关系",
    "简称关系",
    "别名关系",
    "演变关系（古今、阶段）",
    "替代关系",
    "配套使用关系",
    "使用关系",
    "依赖关系",
    "控制关系",
    "领导关系",
    "合作关系",
    "对抗关系",
    "法律关系",
    "经济关系",
    "文化习俗关系",
    "自然现象关系",
    "物理关系",
    "化学关系",
    "数学关系",
    "抽象概念关系",
    "具体事物关系",
    "可能关系",
    "无逻辑关系",
    "不确定关系"
]

In [15]:
INPUT_FILE = "RelationBook_train_valid.jsonl"
OUTPUT_FILE = "mapped_relations.csv"
BATCH_SIZE = 128

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side='left')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
print("模型已加载完成")

# ================= 数据提取函数 =================
def extract_relations_from_jsonl(file_path):
    """
    从 JSONL 文件中提取所有关系字段（question_relation 和 relation_A~D）
    返回列表，每个元素为字典：{'id': 带后缀的唯一标识, 'relation': 关系字符串}
    """
    items = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"第 {line_num} 行 JSON 解析错误: {e}")
                continue

            base_id = data.get('id', f'line_{line_num}')
            model_ext = data.get('model_extraction', {})

            # 提取 question_relation
            q_rel = model_ext.get('question_relation')
            if q_rel:
                items.append({
                    'id': f"{base_id}_question",
                    'relation': q_rel
                })

            # 提取 relation_A 到 relation_D
            for suffix in ['A', 'B', 'C', 'D']:
                rel = model_ext.get(f'relation_{suffix}')
                if rel:
                    items.append({
                        'id': f"{base_id}_{suffix}",
                        'relation': rel
                    })

    return items

# ================= 批处理推理函数 =================
def get_labels_from_gemma_batch(batch_items, relation_labels):
    """
    对一批关系文本进行分类，返回模型原始输出列表
    """
    # 构建候选标签字符串
    tags_str = "\n".join([f"- {tag}" for tag in relation_labels])

    all_messages = []
    for item in batch_items:
        user_content = f"""你是一位关系分类专家。你的任务是将给定的关系短语归类到预定义的标准关系类型列表中。

### 标准关系类型列表：
{tags_str}

### 待分类的关系短语：
{item['relation']}

### 指令：
请从“标准关系类型列表”中选择最匹配的一个类型。严格仅输出列表中的类型名称，不要输出任何其他文本。
"""
        messages = [{"role": "user", "content": user_content}]
        all_messages.append(messages)

    # 批量 tokenize
    input_ids = tokenizer.apply_chat_template(
        all_messages,
        return_tensors="pt",
        add_generation_prompt=True,
        padding=True,
        truncation=True
    ).to(model.device)

    # 生成
    outputs = model.generate(
        input_ids,
        max_new_tokens=32,
        do_sample=False,
        temperature=None,
        top_p=None
    )

    # 提取新生成的 token
    input_len = input_ids.shape[1]
    generated = outputs[:, input_len:]
    responses = tokenizer.batch_decode(generated, skip_special_tokens=True)

    return [res.strip() for res in responses]

def clean_label(raw_output, relation_labels):
    """
    清洗模型输出，映射到 relation_labels 中的某个标签
    """
    # 去除可能的 "类别：" 等前缀和标点
    cleaned = raw_output.replace("Category:", "").replace("类别：", "").strip()
    cleaned = cleaned.rstrip("。.")

    # 完全匹配
    if cleaned in relation_labels:
        return cleaned

    # 模糊匹配：检查 cleaned 是否包含某个标签，或标签包含 cleaned
    for tag in relation_labels:
        if tag in cleaned or cleaned in tag:
            return tag

    return "Uncategorized"

# ================= 主程序 =================
def main():
    # 1. 读取数据
    print("正在从 JSONL 提取关系...")
    all_items = extract_relations_from_jsonl(INPUT_FILE)
    print(f"共提取 {len(all_items)} 条关系记录。")

    if not all_items:
        print("未提取到任何关系，程序退出。")
        return

    # 2. 确定关系标签列表
    relation_labels = Relation_Labels
    if not relation_labels:
        # 自动收集唯一关系值作为默认标签
        unique_relations = sorted(set(item['relation'] for item in all_items))
        relation_labels = unique_relations
        print(f"未定义 Relation_Labels，已自动从数据中收集 {len(relation_labels)} 个唯一关系作为标签。")
    else:
        print(f"使用用户定义的标签列表，共 {len(relation_labels)} 个标签。")

    # 3. 批处理推理
    print(f"开始批处理推理 (Batch Size: {BATCH_SIZE})...")
    with open(OUTPUT_FILE, 'w', newline='', encoding='utf-8-sig') as csvfile:
        fieldnames = ['id', 'relation', 'label', 'raw_output']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for i in tqdm(range(0, len(all_items), BATCH_SIZE), desc="Processing Batches"):
            batch = all_items[i:i+BATCH_SIZE]
            try:
                raw_outputs = get_labels_from_gemma_batch(batch, relation_labels)

                for item, raw_out in zip(batch, raw_outputs):
                    label = clean_label(raw_out, relation_labels)
                    writer.writerow({
                        'id': item['id'],
                        'relation': item['relation'],
                        'label': label,
                        'raw_output': raw_out
                    })
            except Exception as e:
                print(f"处理批次 {i//BATCH_SIZE + 1} 时出错: {e}")
                # 将错误写入 CSV
                for item in batch:
                    writer.writerow({
                        'id': item['id'],
                        'relation': item['relation'],
                        'label': 'Error',
                        'raw_output': str(e)
                    })

    print(f"\n处理完成！结果已保存至: {OUTPUT_FILE}")

if __name__ == "__main__":
    main()


Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.86s/it]


模型已加载完成
正在从 JSONL 提取关系...
共提取 4315 条关系记录。
使用用户定义的标签列表，共 81 个标签。
开始批处理推理 (Batch Size: 128)...


Processing Batches:   0%|          | 0/34 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Processing Batches: 100%|██████████| 34/34 [04:09<00:00,  7.35s/it]


处理完成！结果已保存至: mapped_relations.csv





In [16]:
import json
import os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv

# 加载环境变量
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# ================= 配置区域 =================
MODEL_ID = "google/gemma-2-9b-it"
INPUT_FILE = r"D:\VH_analogy\ekar_test_cleaned.json"
OUTPUT_FILE = "ekar_test_relations.jsonl"
BATCH_SIZE = 64

# ================= 模型初始化 =================
print(f"正在加载模型: {MODEL_ID} ...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side='left', token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN
)
print(">>> 模型已以 BF16 加载完成")

# ================= Prompt 构造函数 =================
def construct_prompt(item):
    q_pair = item['question']
    choices = item['choices']['text']
    answer_key = item['answerKey']

    prompt = f"""作为一个语言逻辑专家，你的任务是分析类比推理题。
输入：
题目: {q_pair}
选项:
A: {choices[0]}
B: {choices[1]}
C: {choices[2]}
D: {choices[3]}

任务：
1. 分析【题目词对】的核心逻辑关系。
2. 独立分析每个【选项词对】的逻辑关系。
3. 不要输出答案。

格式要求：
- 必须确保提取的关系是从<词1>相对<词2>的视角出发的
- 仅输出 JSON。

输出示例：
{{
    "q_relation": "词1是词2的组成部分",
    "a_relation": "位置是高度的决定因素",
    "b_relation": "...",
    "c_relation": "...",
    "d_relation": "..."
}}"""
    return prompt

# ================= 批处理推理函数 =================
def get_relations_from_gemma_batch(batch_items):
    """
    批量获取 Gemma 的关系提取结果
    """
    all_messages = []
    for item in batch_items:
        prompt_text = construct_prompt(item)
        messages = [{"role": "user", "content": prompt_text}]
        all_messages.append(messages)

    input_ids = tokenizer.apply_chat_template(
        all_messages,
        return_tensors="pt",
        add_generation_prompt=True,
        padding=True,
        truncation=True
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>")
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=False,
    )

    input_ids_len = input_ids.shape[1]
    generated_tokens = outputs[:, input_ids_len:]
    responses = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    return [res.strip() for res in responses]

# ================= 主程序 =================
def main():
    # 读取数据
    print(f"正在读取 {INPUT_FILE} ...")
    with open(INPUT_FILE, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"共 {len(data)} 条数据")

    # 断点续传
    processed_ids = set()
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    record = json.loads(line)
                    if 'id' in record:
                        processed_ids.add(record['id'])
                except:
                    pass
    print(f"跳过已处理: {len(processed_ids)} 条")

    # 主循环
    with open(OUTPUT_FILE, 'a', encoding='utf-8') as f_out:
        for i in tqdm(range(0, len(data), BATCH_SIZE), desc="Gemma 关系提取"):
            batch = data[i:i + BATCH_SIZE]
            batch_to_process = [item for item in batch if item['id'] not in processed_ids]

            if not batch_to_process:
                continue

            try:
                responses = get_relations_from_gemma_batch(batch_to_process)

                for item, response_text in zip(batch_to_process, responses):
                    item_id = item['id']

                    # 清洗 JSON
                    clean_json_str = response_text.replace("```json", "").replace("```", "").strip()
                    try:
                        relations = json.loads(clean_json_str)
                    except json.JSONDecodeError:
                        relations = {"error": "JSON_DECODE_FAIL", "raw_content": response_text}

                    final_record = {
                        "id": item_id,
                        "relations": relations
                    }

                    f_out.write(json.dumps(final_record, ensure_ascii=False) + "\n")
                    f_out.flush()

            except Exception as e:
                print(f"处理批次 {i//BATCH_SIZE + 1} 时出错: {e}")
                for item in batch_to_process:
                    final_record = {
                        "id": item['id'],
                        "relations": {"error": str(e)}
                    }
                    f_out.write(json.dumps(final_record, ensure_ascii=False) + "\n")
                    f_out.flush()

    print(f"处理完成，结果保存到 {OUTPUT_FILE}")

if __name__ == "__main__":
    main()

正在加载模型: google/gemma-2-9b-it ...


Loading checkpoint shards: 100%|██████████| 4/4 [00:12<00:00,  3.19s/it]


>>> 模型已以 BF16 加载完成
正在读取 D:\VH_analogy\ekar_test_cleaned.json ...
共 205 条数据
跳过已处理: 0 条


Gemma 关系提取:   0%|          | 0/4 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Gemma 关系提取: 100%|██████████| 4/4 [00:43<00:00, 10.98s/it]

处理完成，结果保存到 ekar_test_relations.jsonl



