# 用于处理多模态RAG图文问答挑战赛训练集JSON文件，提取问题与答案构建QA对
本notebook将演示如何将原始训练集转换为标准QA（SFT）格式，便于后续微调；同时会基于PDF分页上下文进行数据增强（同义改写、问题生成），并构建供RLHF奖励模型训练使用的偏好数据。

In [1]:
#

In [None]:
import json
import os
import re
import logging
from typing import Dict, List, Any, Optional
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Configure logging: INFO and above goes both to a log file and to the console.
# The file handler is opened as UTF-8 explicitly because most log messages are
# Chinese; with a non-UTF-8 platform default encoding (e.g. cp1252 on Windows)
# those records would fail to encode and be lost.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("finance_data_process.log", encoding="utf-8"),
        logging.StreamHandler(),
    ],
)

# Financial-domain vocabulary used for entity extraction, answer validation and
# task typing. Keys are the category names: indicators ("指标"), subjects
# ("主体"), document types ("文档类型") and time dimensions ("时间维度").
FINANCIAL_ENTITIES = {
    "指标": [
        "营收", "营业收入", "净利润", "毛利润", "毛利率", "净利率",
        "资产负债率", "ROE", "ROA", "同比增长率", "环比增长率",
        "每股收益", "市盈率", "市净率", "现金流", "资产总额",
        "负债总额", "股东权益", "研发费用", "销售费用"
    ],
    "主体": [
        "公司", "企业", "银行", "基金", "股票", "债券",
        "上市公司", "母公司", "子公司", "金融机构", "券商",
        "保险公司", "信托公司"
    ],
    "文档类型": [
        "财报", "研报", "年报", "季报", "中报", "公告",
        "招股书", "债券募集说明书", "评级报告", "尽职调查报告"
    ],
    "时间维度": ["年度", "季度", "月度", "半年度", "同比", "环比"]
}

# Lazily loaded, shared pretrained models (loaded once and reused).
class ModelCache:
    """Class-level lazy cache for the Hugging Face models used in augmentation.

    Each model is loaded on first request and then reused. A failed load is
    logged and re-raised; the model slot stays empty, so a later call retries.
    """

    _qg_model = None          # text2text-generation pipeline for question generation
    _paraphrase_model = None  # seq2seq model used for question paraphrasing
    _tokenizer = None         # tokenizer paired with _paraphrase_model

    @classmethod
    def get_qg_pipeline(cls):
        """Return the finance question-generation pipeline, loading it on first use."""
        if cls._qg_model is not None:
            return cls._qg_model
        try:
            device_index = 0 if torch.cuda.is_available() else -1
            cls._qg_model = pipeline(
                "text2text-generation",
                model="lmqg/t5-base-finance-qg",
                device=device_index,
            )
            logging.info("成功加载金融问题生成模型")
        except Exception as err:
            logging.error(f"金融问题生成模型加载失败: {str(err)}")
            raise
        return cls._qg_model

    @classmethod
    def get_paraphrase_model(cls):
        """Return ``(model, tokenizer)`` for paraphrasing, loading both on first use."""
        if cls._paraphrase_model is not None:
            return cls._paraphrase_model, cls._tokenizer
        checkpoint = "Vamsi/T5_Paraphrase_Paws"
        try:
            cls._tokenizer = AutoTokenizer.from_pretrained(checkpoint)
            loaded = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
            cls._paraphrase_model = loaded.to("cuda" if torch.cuda.is_available() else "cpu")
            logging.info("成功加载同义改写模型")
        except Exception as err:
            logging.error(f"同义改写模型加载失败: {str(err)}")
            raise
        return cls._paraphrase_model, cls._tokenizer


def extract_finance_entities(
    text: str,
    entity_vocab: Optional[Dict[str, List[str]]] = None,
) -> Dict[str, List[str]]:
    """Find which vocabulary terms occur in *text*, grouped by category.

    Matching is case-insensitive. Terms containing non-ASCII characters (the
    Chinese terms) use plain substring matching, so they are found inside
    running Chinese text ("净利润" in "净利润是多少", "营收" in "营收额").
    Regex ``\\b`` cannot be used here because CJK characters count as word
    characters, so ``\\b`` never fires between them. Pure-ASCII terms such as
    "ROE" must not be flanked by ASCII letters/digits, so "ROA" does not match
    inside "ROAD", while "ROE为15%" still matches.

    Args:
        text: Text to scan. Empty or ``None`` input yields empty lists.
        entity_vocab: Mapping of category -> terms. Defaults to the module-level
            ``FINANCIAL_ENTITIES``.

    Returns:
        A dict with one key per vocabulary category (vocabulary order). Each
        value lists the matched terms in vocabulary order, without duplicates.
    """
    vocab = FINANCIAL_ENTITIES if entity_vocab is None else entity_vocab
    entities: Dict[str, List[str]] = {category: [] for category in vocab}
    if not text:
        return entities
    text_folded = text.lower()

    for category, keywords in vocab.items():
        found = entities[category]
        for keyword in keywords:
            needle = keyword.lower()
            if not needle:
                continue
            if needle.isascii():
                # ASCII terms: require no adjacent ASCII letter/digit.
                hit = re.search(
                    rf"(?<![a-z0-9]){re.escape(needle)}(?![a-z0-9])", text_folded
                ) is not None
            else:
                hit = needle in text_folded
            if hit and keyword not in found:
                found.append(keyword)
    return entities


def is_answer_consistent(question: str, answer: str, context: str = "") -> bool:
    """Heuristically decide whether *answer* is an acceptable label for *question*.

    Checks, in order:
      1. The answer is not blank.
      2. If *context* is available and the answer carries no explicit source
         citation ("xxx.pdf ... 第N页" or "Page N"): whenever the question
         mentions financial indicators, the answer must mention at least one of
         the same indicators. Questions without any indicator (e.g. about a
         company's business) are not filtered by this rule.
      3. If the answer contains a figure with a unit (元 / 美元 / % / ％) and
         the question asks about a ratio indicator (a term containing "率"),
         the answer must contain a percent sign, half-width "%" or full-width "％".

    Args:
        question: The question text.
        answer: The candidate answer text.
        context: Source page text. Rule 2 only applies when it is non-empty.

    Returns:
        True if the sample passes all checks, otherwise False (a warning is logged).
    """
    if not answer.strip():
        logging.warning("过滤空答案样本")
        return False

    # Rule 2: without an explicit citation, require indicator overlap.
    source_pattern = r"([\u4e00-\u9fa5\w]+\.pdf).*?(第\d+页|Page \d+)"
    if context and not re.search(source_pattern, answer):
        question_indicators = extract_finance_entities(question)["指标"]
        if question_indicators:
            answer_indicators = set(extract_finance_entities(answer)["指标"])
            if not answer_indicators.intersection(question_indicators):
                logging.warning(f"答案与问题无共享指标: {question[:30]}...")
                return False

    # Rule 3: ratio indicators must be answered with a percentage.
    num_pattern = r"(\d+,\d+|\d+)(\.\d+)?(万|千|百|亿|万亿)?(元|美元|%|％)"
    if re.search(num_pattern, answer):
        has_percent = "%" in answer or "％" in answer
        if not has_percent:
            for indicator in FINANCIAL_ENTITIES["指标"]:
                if "率" in indicator and indicator in question:
                    logging.warning(f"比率指标答案缺少百分号: {indicator}")
                    return False

    return True


def paraphrase_question(question: str, num_variations: int = 2) -> List[str]:
    """Generate paraphrases of *question* with the cached T5 paraphrase model.

    Sampling is used (top-k 50, temperature 0.7), so results vary between
    calls. Empty decodes, decodes identical to *question*, and duplicates are
    dropped, so fewer than *num_variations* strings may be returned. These
    would otherwise become duplicate training samples.

    NOTE(review): "Vamsi/T5_Paraphrase_Paws" is trained on English PAWS data.
    Chinese questions may come back empty or unchanged; confirm the model
    choice for this dataset.

    Args:
        question: Question to paraphrase. Blank input returns ``[]``.
        num_variations: Number of sequences to sample. Values <= 0 return ``[]``.

    Returns:
        Distinct paraphrases in generation order.
    """
    original = question.strip()
    if num_variations <= 0 or not original:
        return []

    model, tokenizer = ModelCache.get_paraphrase_model()
    inputs = tokenizer(
        f"paraphrase: {question}",
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_length=128,
        num_return_sequences=num_variations,
        do_sample=True,
        top_k=50,
        temperature=0.7
    )

    variations: List[str] = []
    for output in outputs:
        text = tokenizer.decode(output, skip_special_tokens=True).strip()
        if text and text != original and text not in variations:
            variations.append(text)
    return variations


def generate_hard_negatives(
    question: str,
    correct_answer: str,
    context: str,
    max_negatives: int = 2,
    length_tolerance: int = 50,
) -> List[str]:
    """Pick context sentences that look like, but are not, the correct answer.

    The context is split on "。", "；", "！" and newlines. A stripped piece
    qualifies when all of the following hold:
      * it is longer than 10 characters;
      * its length is within *length_tolerance* of ``len(correct_answer)``;
      * it is not part of the correct answer and does not contain it.
        A piece containing the answer would make a falsely "rejected" sample.
    In each selected piece, the first indicator found in the correct answer is
    rewritten as "错误<indicator>" to make it misleading.

    Args:
        question: Currently unused; kept for interface compatibility.
        correct_answer: The reference answer.
        context: Source text to draw negatives from. Empty input returns ``[]``.
        max_negatives: Maximum number of negatives to return (default 2).
        length_tolerance: Maximum allowed length difference (exclusive, default 50).

    Returns:
        Up to *max_negatives* negative answer strings, in context order.
    """
    if not context or max_negatives <= 0:
        return []

    correct_len = len(correct_answer)
    candidates: List[str] = []
    for piece in re.split(r"。|；|！|\n", context):
        chunk = piece.strip()
        if len(chunk) <= 10:
            continue
        if abs(len(chunk) - correct_len) >= length_tolerance:
            continue
        if chunk in correct_answer or correct_answer in chunk:
            continue
        candidates.append(chunk)
        if len(candidates) >= max_negatives:
            break

    if not candidates:
        return []

    # Entity extraction depends only on the answer; compute it once.
    indicators = extract_finance_entities(correct_answer)["指标"][:1]
    negatives: List[str] = []
    for chunk in candidates:
        for indicator in indicators:
            chunk = chunk.replace(indicator, f"错误{indicator}")
        negatives.append(chunk)
    return negatives


def augment_data(context: str, original_question: str, original_answer: str) -> List[Dict[str, Any]]:
    """Build augmented samples for one QA pair.

    Three best-effort steps; each logs and skips on failure, including
    model-loading failures:
      1. paraphrases of the question, paired with the original answer;
      2. new questions generated for the *answer span* in the context, paired
         with the original answer;
      3. hard negatives (misleading answers for the original question), for
         preference data.

    Args:
        context: Source page text.
        original_question: The question of the source sample.
        original_answer: The answer of the source sample.

    Returns:
        A list of dicts with keys ``type`` ("paraphrase" / "new_question" /
        "hard_negative"), ``context``, ``question``, ``answer`` and ``is_positive``.
    """
    augmented: List[Dict[str, Any]] = []

    # 1. Paraphrase the original question.
    try:
        paraphrased = paraphrase_question(original_question)
        for p in paraphrased:
            augmented.append({
                "type": "paraphrase",
                "context": context,
                "question": p,
                "answer": original_answer,
                "is_positive": True
            })
    except Exception as e:
        logging.error(f"同义改写失败: {str(e)}")

    # 2. Answer-aware question generation. A question generated from the bare
    #    context is usually about something else, so pairing it with
    #    original_answer would create a mislabeled positive. The answer span is
    #    highlighted instead, so the generated question targets it. The step is
    #    skipped when the answer does not appear verbatim in the context.
    answer_span = original_answer.strip()
    if answer_span and answer_span in context:
        try:
            qg_pipeline = ModelCache.get_qg_pipeline()
            # NOTE(review): "generate question: ... <hl> answer <hl> ..." is the
            # input format of lmqg QG checkpoints; confirm on the model card.
            highlighted = context.replace(answer_span, f"<hl> {answer_span} <hl>", 1)
            generated_questions = qg_pipeline(
                f"generate question: {highlighted}",
                max_length=100,
                num_return_sequences=2,
                # Sampling is required: greedy decoding rejects
                # num_return_sequences > 1 and ignores temperature.
                do_sample=True,
                temperature=0.8
            )
            seen = {original_question.strip()}
            for q in generated_questions:
                new_question = q["generated_text"].strip()
                if not new_question or new_question in seen:
                    continue
                seen.add(new_question)
                augmented.append({
                    "type": "new_question",
                    "context": context,
                    "question": new_question,
                    "answer": original_answer,
                    "is_positive": True
                })
        except Exception as e:
            logging.error(f"新问题生成失败: {str(e)}")

    # 3. Hard negatives (used as "rejected" answers in preference data).
    try:
        hard_negatives = generate_hard_negatives(original_question, original_answer, context)
        for neg in hard_negatives:
            augmented.append({
                "type": "hard_negative",
                "context": context,
                "question": original_question,
                "answer": neg,
                "is_positive": False
            })
    except Exception as e:
        logging.error(f"硬负例生成失败: {str(e)}")

    return augmented


def build_finance_sft_data(
    raw_data: List[Dict[str, Any]],
    chunk_data_path: str = "all_pdf_page_chunks.json",
    include_rlhf_data: bool = True
) -> Dict[str, List[Dict[str, Any]]]:
    """构建增强版金融SFT数据，包含RLHF所需的偏好数据结构"""
    # 加载PDF分页上下文
    chunk_context = {}
    if os.path.exists(chunk_data_path):
        with open(chunk_data_path, "r", encoding="utf-8") as f:
            chunks = json.load(f)
            for chunk in chunks:
                key = f"{chunk['metadata']['file_name']}_page_{chunk['metadata']['page']}"
                chunk_context[key] = chunk['content']
        logging.info(f"已加载 {len(chunk_context)} 条PDF分页上下文")
    else:
        logging.warning(f"未找到PDF上下文文件: {chunk_data_path}")

    # 初始化数据容器
    output = {
        "sft_data": [],          # 用于监督微调的数据
        "rlhf_preference_data": []  # 用于RLHF的偏好数据
    }

    # 处理原始数据
    for idx, item in enumerate(raw_data):
        if idx % 100 == 0:
            logging.info(f"处理第 {idx}/{len(raw_data)} 条原始数据")

        question = item.get("question", "").strip()
        answer = item.get("answer", "").strip()
        file_name = item.get("file_name", "未知")
        page = item.get("page", "未知")
        context_key = f"{file_name}_page_{page}"
        context = chunk_context.get(context_key, "")

        # 过滤无效样本
        if not question or not answer:
            logging.warning(f"跳过无效样本 (问题或答案为空): {question[:30]}...")
            continue
        if not is_answer_consistent(question, answer, context):
            logging.warning(f"跳过不一致样本: {question[:30]}...")
            continue

        # 提取金融实体
        entities = extract_finance_entities(question + " " + answer)

        # 细化金融任务类型
        question_lower = question.lower()
        task_type = "金融问答"
        if any(ind in question_lower for ind in FINANCIAL_ENTITIES["指标"]):
            task_type = "财务指标提取"
        elif "同比" in question_lower or "环比" in question_lower or "趋势" in question_lower:
            task_type = "数据趋势分析"
        elif "解释" in question_lower or "含义" in question_lower or "什么是" in question_lower:
            task_type = "金融术语解读"
        elif "图表" in question_lower or "图形" in question_lower or "表格" in question_lower:
            task_type = "图文分析"
        elif "预测" in question_lower or "展望" in question_lower:
            task_type = "财务预测"

        # 构建SFT样本
        sft_sample = {
            "system": (
                "你是专业金融分析师助手，需基于财报、研报等金融文档内容回答问题。"
                "要求：1. 严格使用规范金融术语；2. 数据需标注来源（如XX.pdf第X页）；"
                "3. 区分事实陈述与分析观点；4. 图表相关问题需结合数据趋势解读。"
            ),
            "instruction": question,
            "context": context,
            "output": answer,
            "task_type": task_type,
            "entities": entities,
            "metadata": {
                "file_name": file_name,
                "page": page,
                "has_image": "图表" in question_lower or "图形" in question_lower,
                "raw_index": idx
            }
        }
        output["sft_data"].append(sft_sample)

        # 数据增强（生成额外样本）
        if context:
            augmented_samples = augment_data(context, question, answer)
            for aug in augmented_samples:
                # 增强样本加入SFT数据
                if aug["is_positive"]:
                    output["sft_data"].append({
                        **sft_sample,
                        "instruction": aug["question"],
                        "metadata": {
                            **sft_sample["metadata"],
                            "aug_type": aug["type"]
                        }
                    })
                # 硬负例加入RLHF偏好数据
                if include_rlhf_data and not aug["is_positive"]:
                    output["rlhf_preference_data"].append({
                        "question": question,
                        "context": context,
                        "chosen": answer,          # 优质答案
                        "rejected": aug["answer"], # 劣质答案（负例）
                        "metadata": sft_sample["metadata"]
                    })

    logging.info(f"数据处理完成 - SFT样本数: {len(output['sft_data'])}，RLHF偏好样本数: {len(output['rlhf_preference_data'])}")
    return output


def main(
    train_data_path: str = "data/train.json",
    chunk_data_path: str = "all_pdf_page_chunks.json",
    output_dir: str = "data/processed"
):
    """Load the raw training set, build SFT/RLHF data and write it to *output_dir*.

    Writes ``finance_sft_train.json``, ``finance_sft_train.csv`` (UTF-8 with
    BOM) and ``finance_rlhf_preference.json``. The output directory is created
    before the input is checked.

    Raises:
        FileNotFoundError: if *train_data_path* does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(train_data_path):
        raise FileNotFoundError(f"训练集文件不存在: {train_data_path}")
    with open(train_data_path, "r", encoding="utf-8") as fh:
        raw_records = json.load(fh)
    logging.info(f"加载原始训练数据 - 样本数: {len(raw_records)}")

    result = build_finance_sft_data(raw_records, chunk_data_path)
    sft_records = result["sft_data"]

    def _dump_json(records, file_name):
        # Pretty-printed UTF-8 JSON, keeping Chinese characters unescaped.
        target = os.path.join(output_dir, file_name)
        with open(target, "w", encoding="utf-8") as fh:
            json.dump(records, fh, ensure_ascii=False, indent=2)
        return target

    sft_json_path = _dump_json(sft_records, "finance_sft_train.json")
    logging.info(f"SFT数据已保存: {sft_json_path}")

    # CSV copy for quick inspection in spreadsheet tools.
    sft_csv_path = os.path.join(output_dir, "finance_sft_train.csv")
    pd.DataFrame(sft_records).to_csv(sft_csv_path, index=False, encoding="utf-8-sig")
    logging.info(f"SFT数据CSV已保存: {sft_csv_path}")

    # Preference pairs for later reward-model training.
    rlhf_json_path = _dump_json(result["rlhf_preference_data"], "finance_rlhf_preference.json")
    logging.info(f"RLHF偏好数据已保存: {rlhf_json_path}")


if __name__ == "__main__":
    # Paths are configured via command-line flags (defaults listed below);
    # environment variables are not read.
    import argparse

    cli = argparse.ArgumentParser(description="金融领域SFT数据构建工具（含RLHF支持）")
    for flag, default_value, help_text in (
        ("--train_data", "data/train.json", "原始训练数据路径"),
        ("--chunk_data", "all_pdf_page_chunks.json", "PDF分页上下文路径"),
        ("--output_dir", "data/processed", "输出目录"),
    ):
        cli.add_argument(flag, default=default_value, help=help_text)
    parsed = cli.parse_args()

    main(
        train_data_path=parsed.train_data,
        chunk_data_path=parsed.chunk_data,
        output_dir=parsed.output_dir,
    )

## 2. 加载JSON数据集
读取多模态RAG图文问答挑战赛的训练集JSON文件（默认路径为 `data/train.json`，可通过 `--train_data` 参数指定），解析为Python对象。

## 4. 构建QA对列表
遍历数据集，将每条数据的question和answer字段提取出来，构建QA对列表：question作为`instruction`，answer作为`output`，`system`字段统一设置为金融分析师助手的角色提示词（“你是专业金融分析师助手……”，完整内容见代码中的`system`字段）。

## 5. 保存处理后的QA数据
将处理后的QA对列表同时保存为JSON和CSV文件（`finance_sft_train.json` / `finance_sft_train.csv`），并另存RLHF偏好数据 `finance_rlhf_preference.json`，便于后续模型训练或分析。