# In [None]:  (cell marker left over from the Jupyter notebook export)
# -*- coding: utf-8 -*-
import os
import json
import time
import random
import re
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from tqdm import tqdm
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional, Tuple
import logging
from collections import Counter
import numpy as np
import requests
import math
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import gc # 添加gc模块用于内存清理

# Configure logging: INFO level, one shared format, two handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("essay_evaluation.log", encoding='utf-8'),  # log file, explicitly UTF-8 encoded
        logging.StreamHandler()  # default stream of StreamHandler (stderr)
    ]
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

# 设置随机种子，确保结果可重现
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices),
    exports PYTHONHASHSEED, and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Each of these accepts the seed as its single positional argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels; benchmarking is turned off alongside it.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Seed all RNGs once at import time with the default seed (42).
set_seed()

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# API-related settings
API_URL = "url"  # replace with your API host; call_llm_api builds https://{API_URL}/v1/chat/completions from it
API_MODEL = "deepseek-chat"   # model name sent in the request
API_MAX_TOKENS = 8192       # raised token limit to obtain more detailed analyses
API_TEMPERATURE = 0.1         # low temperature for more deterministic results
API_STREAM = False

# List of API keys (replace the placeholders with your own keys).
# NOTE(review): keys are hard-coded in source; consider loading them from the environment instead.
API_KEYS = [
    "apikeys",
    "apikeys",
    "apikeys",
    "apikeys",
    "apikeys"
]

# Data paths
data_root = "."  # directory that contains the JSON files
input_file = os.path.join(data_root, "test_data.json")  # input essays (changed to test_data.json)
output_file = os.path.join(data_root, "DUFL2025_track1.json")  # output file (renamed to DUFL2025_track1.json)
processed_essays_file = os.path.join(data_root, "processed_essays.json")
bert_model_path = os.path.join(data_root, "bert_essay_model.pt")
samples_file = os.path.join(data_root, "samples.json") # added: path of the few-shot samples file

# Classification labels and the score each one maps to (5 = best, 1 = worst)
CLASSIFICATION_LABELS = {
    "优秀": 5,
    "较好": 4,
    "一般": 3,
    "合格": 2,
    "不合格": 1
}

# Reverse mapping, from score back to label
SCORE_TO_LABEL = {v: k for k, v in CLASSIFICATION_LABELS.items()}

@dataclass
class Message:
    """One chat message exchanged with the LLM API.

    The field order (content, then role) is part of the positional
    constructor signature and of the key order produced by `asdict`.
    """
    content: str  # text of the message
    role: str  # speaker role; this file uses "system" and "user"

@dataclass
class ApifoxModel:
    """Request body for the chat-completions endpoint.

    Only `messages` and `model` are required. Every other field defaults to
    None, and call_llm_api drops None-valued fields before serialising the
    payload to JSON.

    The annotations describe the values actually passed in this file
    (for example temperature=0.1 and max_tokens=8192); dataclasses do not
    enforce them at runtime.
    """
    # call_llm_api passes plain dicts (each Message converted with asdict).
    messages: List[Dict[str, Any]]
    model: str
    # presumably a token-id -> bias mapping as in OpenAI-compatible APIs — TODO confirm
    logit_bias: Optional[Dict[str, float]] = None
    frequency_penalty: Optional[float] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stream: Optional[bool] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    user: Optional[str] = None

# API key management
current_api_idx = 0  # index into API_KEYS of the key currently in use
last_switch_time = time.time()  # time of the most recent key switch (updated by switch_api_key)

def switch_api_key():
    """Rotate to the next API key in round-robin order.

    Updates the module-level `current_api_idx` and `last_switch_time`.

    Returns:
        The newly selected key, or None (after logging an error) when
        API_KEYS is empty.
    """
    global current_api_idx, last_switch_time

    if API_KEYS:
        # Wrap around to the first key after the last one.
        current_api_idx = (current_api_idx + 1) % len(API_KEYS)
        last_switch_time = time.time()
        logger.info(f"Switching to new API key index {current_api_idx}")
        return API_KEYS[current_api_idx]

    # Guard against an empty key list.
    logger.error("API_KEYS list is empty. Cannot switch key.")
    return None

def get_current_api_key():
    """Return the API key currently selected by `current_api_idx`.

    Returns:
        The key string at position `current_api_idx` in API_KEYS.

    Raises:
        ValueError: If API_KEYS is empty.
    """
    # Only reads module state, so no `global` declaration is needed.
    if not API_KEYS:
        raise ValueError("API_KEYS list is empty. Please provide valid API keys.")
    return API_KEYS[current_api_idx]

def call_llm_api(messages, retry_on_empty=True):
    """Call the LLM chat-completions API synchronously via `requests`, with retries and key rotation.

    Args:
        messages: Iterable of dataclass instances (`asdict` is applied to each)
            that make up the conversation.
        retry_on_empty: When True, an empty or whitespace-only reply is treated
            as a failure and retried.

    Returns:
        The reply text on success, otherwise a string starting with
        "API_ERROR:" that describes the failure.

    Raises:
        ValueError: Propagated from get_current_api_key when API_KEYS is empty.
    """
    url = f"https://{API_URL}/v1/chat/completions"
    model = ApifoxModel(
        model=API_MODEL,
        messages=[asdict(m) for m in messages],  # make sure messages is a list of dicts
        max_tokens=API_MAX_TOKENS,
        temperature=API_TEMPERATURE,
        stream=API_STREAM,
    )

    # Drop None-valued fields so the payload matches what the API expects.
    payload_dict = {k: v for k, v in asdict(model).items() if v is not None}
    payload = json.dumps(payload_dict)

    num_retries = 100  # generous retry budget
    # Upper bound for a single backoff sleep. Without it, 2 ** (attempt + 1)
    # reaches hours after a dozen failures and eventually overflows time.sleep.
    max_backoff_seconds = 60

    def backoff_delay(attempt_index, jitter=True):
        """Return the capped exponential backoff (seconds) for this attempt."""
        delay = min(2 ** (attempt_index + 1), max_backoff_seconds)
        return delay + random.uniform(0, 1) if jitter else delay

    for attempt in range(num_retries):
        current_key = get_current_api_key()
        if not current_key:  # e.g. an empty-string key
            logger.error("No valid API key available.")
            return "API_ERROR: No valid API key"

        headers = {
            "Authorization": f"Bearer {current_key}",
            "Content-Type": "application/json"
        }
        try:
            # Short random pause before each request to avoid hammering the API.
            time.sleep(random.uniform(0.5, 1.5))
            logger.debug(f"Attempt {attempt + 1}/{num_retries}: Calling API {url} with model {API_MODEL}")

            response = requests.post(url, data=payload, headers=headers, timeout=180)  # long timeout for large completions

            logger.debug(f"Attempt {attempt + 1}: Received status {response.status_code}")
            if response.status_code == 429:  # Rate limit error
                logger.warning(f"Rate limit exceeded (429). Attempt {attempt + 1}/{num_retries}. Switching key and retrying after delay.")
                new_key = switch_api_key()  # Switch key immediately on rate limit
                if not new_key:
                    return "API_ERROR: No valid API key after switch"
                time.sleep(backoff_delay(attempt))  # capped exponential backoff
                continue  # retry with the new key after the backoff

            response.raise_for_status()  # Raise exception for other bad statuses (4xx, 5xx)
            json_data = response.json()
            logger.debug(f"Attempt {attempt + 1}: Received JSON response.")

            if "choices" in json_data and len(json_data["choices"]) > 0 and \
               "message" in json_data["choices"][0] and "content" in json_data["choices"][0]["message"]:
                content = json_data["choices"][0]["message"]["content"]
                # Treat an empty reply as a retryable failure when requested.
                if retry_on_empty and (not content or content.strip() == ""):
                    logger.warning(f"Empty API response content received (attempt {attempt + 1}/{num_retries})")
                    if attempt < num_retries - 1:
                        time.sleep(backoff_delay(attempt, jitter=False))
                        continue
                    else:
                        # Final attempt failed due to empty response
                        logger.error("API returned empty response content after max retries.")
                        return "API_ERROR: Empty response after max retries"
                logger.debug(f"Attempt {attempt + 1}: Successfully extracted content.")
                return content
            else:
                logger.error(f"Invalid API response structure: {json_data} (attempt {attempt + 1}/{num_retries})")

        except requests.exceptions.HTTPError as e:
            logger.error(f"HTTP Error (attempt {attempt + 1}/{num_retries}): Status {e.response.status_code}, URL: {e.request.url}")
            # Read response body for more details if available
            try:
                error_body = e.response.text
                logger.error(f"Error Body: {error_body[:500]}...")  # Log first 500 chars
            except Exception as body_e:
                logger.error(f"Could not read error response body: {body_e}")

            if e.response.status_code == 401:  # Unauthorized - likely bad key
                logger.error("Authorization Error (401). Switching API key.")
                new_key = switch_api_key()
                if not new_key:
                    return "API_ERROR: No valid API key after switch for 401"
            elif e.response.status_code == 400:  # Bad Request - often payload issues
                # NOTE(review): a 400 is still retried like any other failure.
                logger.error(f"Bad Request (400). Check API payload. Payload snippet: {payload[:200]}...")
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            logger.error(f"Network/Connection Error (attempt {attempt + 1}/{num_retries}): {type(e).__name__} - {e}")
        except json.JSONDecodeError as e:
            logger.error(f"Failed to decode JSON response (attempt {attempt + 1}/{num_retries}): {e}")
            # Log the raw response text if possible
            try:
                raw_text = response.text
                logger.error(f"Raw response text (on JSON decode error): {raw_text[:500]}...")  # Log first 500 chars
            except Exception as text_e:
                logger.error(f"Could not get raw response text: {text_e}")
        except Exception as e:
            # Log the traceback for anything unexpected.
            logger.error(f"Unexpected error during API call (attempt {attempt + 1}/{num_retries}): {type(e).__name__} - {e}", exc_info=True)

        # Retry logic shared by every failure path that did not `continue` or return.
        if attempt < num_retries - 1:
            wait_time = backoff_delay(attempt)
            logger.info(f"Retrying in {wait_time:.2f} seconds...")
            time.sleep(wait_time)  # capped exponential backoff
        else:
            logger.warning(f"Max retries reached for API call to {url}.")
            return "API_ERROR: Max retries reached"

    # Reached when the last attempt ended with `continue` (e.g. a 429 on the final try).
    logger.error("API call loop completed without returning a value or error.")
    return "API_ERROR: Unknown failure after retries"

def extract_classification(response: str) -> str:
    """Extract the classification label from an LLM response.

    Matching is attempted in this order:
      1. a known heading (e.g. "最终分类", "Final Classification") followed by
         a label, tolerating colons, whitespace, markdown emphasis, quotes,
         brackets and one short parenthetical between heading and label;
      2. a line that starts with a label which is then followed by
         end-of-line or punctuation;
      3. any label that appears as a delimited word anywhere in the text.

    Labels are matched as whole words through alternation, so "不合格" is never
    read as "合格" and arbitrary mixes of label characters are not accepted.
    English labels are mapped to their Chinese equivalents.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        One of the keys of CLASSIFICATION_LABELS. Falls back to '一般' when the
        response is empty, contains 'API_ERROR', or holds no recognisable label.
    """
    if not response or 'API_ERROR' in response:
        return '一般'  # default classification

    # English -> Chinese label mapping
    eng_to_chn = {"Excellent": "优秀", "Good": "较好", "Average": "一般", "Qualified": "合格", "Unqualified": "不合格"}

    # Longest first so "不合格" wins over "合格" and "Unqualified" over "Qualified".
    chn_labels = sorted(CLASSIFICATION_LABELS.keys(), key=len, reverse=True)
    eng_labels = sorted(eng_to_chn.keys(), key=len, reverse=True)
    label_alt = '|'.join(re.escape(text) for text in chn_labels + eng_labels)
    # The lookahead stops an English label from matching a prefix of a longer word.
    label_group = r'(' + label_alt + r')(?![A-Za-z])'

    # Characters allowed between a heading and its label, plus one optional short
    # parenthetical such as "(Final Classification)".
    filler = r'[\s:：*"“”\'‘’「」『』《》【】\[\](（]*'
    gap = filler + r'(?:[(（][^()（）\n]{0,40}[)）])?' + filler

    # Headings in priority order.
    headings = [
        '最终分类',
        '分类',
        '切题度初步评级',  # from Expert 3
        '素材切题度评级',  # from Expert 4
        '素材服务主旨的总体有效性评级',  # from Expert 4 (alternative phrasing)
        'Final Classification',
        'Classification',
    ]
    patterns = [re.escape(heading) + gap + label_group for heading in headings]
    # A line that begins with a label followed by end-of-line or punctuation.
    patterns.append(r'^[\s*]*' + label_group + r'(?=$|[\s*，,。.；;：:！!、)）"”])')

    for pattern in patterns:
        match = re.search(pattern, response, re.IGNORECASE | re.MULTILINE)
        if match:
            label = match.group(1).strip()
            # Map English to Chinese; capitalize() normalises case-insensitive hits.
            label = eng_to_chn.get(label.capitalize(), label)
            if label in CLASSIFICATION_LABELS:
                logger.debug(f"Extracted classification '{label}' using pattern: {pattern}")
                return label

    # Last resort: a label appearing as a delimited word anywhere in the text.
    # The delimiter set is a superset of the one used previously.
    delimiters = r'\s:：，。！；、,.;!?？$(){}<>\[\]（）【】《》"\'“”‘’*'
    all_labels_to_check = chn_labels + list(eng_to_chn.keys())  # Chinese labels first

    for label_text in all_labels_to_check:
        pattern = r'(?:^|[' + delimiters + r'])(' + re.escape(label_text) + r')(?:$|[' + delimiters + r'])'
        match = re.search(pattern, response, re.IGNORECASE)
        if match:
            found_label = match.group(1)
            chn_label = eng_to_chn.get(found_label.capitalize(), found_label)
            if chn_label in CLASSIFICATION_LABELS:
                logger.debug(f"Extracted classification '{chn_label}' by direct text match.")
                return chn_label

    # Nothing recognisable: fall back to the default.
    logger.warning(f"Could not extract classification from response: {response[:200]}... Defaulting to '一般'.")
    return "一般"

# --- Few-Shot Example Formatting ---
def format_few_shot_examples(examples: List[Dict[str, Any]]) -> str:
    """Render few-shot samples as one text block for insertion into a prompt.

    Args:
        examples: Sample dicts. The keys read are 'grade', 'requirement',
            'title', 'content', 'classification' and 'comment'; a missing key
            is replaced by a placeholder.

    Returns:
        An introductory sentence followed by one numbered section per sample,
        or an empty string when `examples` is empty or None.
    """
    if not examples:
        return ""

    pieces = ["以下是一些评估示例，包含每一个种评分的具体情况，可以参考它们的格式、分析思路和最终的分类/评论：\n\n"]
    for number, example in enumerate(examples, start=1):
        pieces.extend((
            f"--- 示例 {number} ---\n",
            f"年级: {example.get('grade', '未知')}\n",
            f"作文要求: {example.get('requirement', '无')}\n",
            f"作文标题: {example.get('title', '无')}\n",
            f"作文内容:\n{example.get('content', '无')}\n",
            f"最终分类: {example.get('classification', '未知')}\n",
            f"评审意见: {example.get('comment', '无')}\n",
            "---\n\n",
        ))

    return "".join(pieces)

# 新增：作文类型识别智能体的提示
def generate_essay_type_classifier_prompt(essay_data: Dict[str, Any], few_shot_examples: List[Dict[str, Any]] = None) -> List[Message]:
    """
    Build the chat messages for the essay-type classifier (Expert 0: Essay Type Classifier).

    Reads 'grade', 'requirement', 'title' and 'content' from `essay_data`
    (each with a placeholder default) and returns a system message followed
    by a user message. `few_shot_examples` is accepted but not used here.
    """
    content = essay_data.get("content", "无内容")
    title = essay_data.get("title", "无标题")
    requirement = essay_data.get("requirement", "无具体要求")
    grade = essay_data.get("grade", "未知")

    system_prompt = f"""
你的角色: 中小学作文体裁鉴别专家。专注于分析作文的体裁类型。

你的任务: 对{grade}年级作文进行分析，识别其所属的文体类型，并提供依据。

分析要点:
1. 根据作文要求、标题和内容，判断作文属于哪种文体类型（记叙文、议论文、说明文、应用文、读后感等）
2. 分析作文的结构特点、表达方式、语言风格等，找出支持文体判断的关键证据
3. 考虑{grade}年级学生的写作水平和认知特点，判断文体的纯度和混合情况

输出要求:
生成《作文类型识别报告》，包含以下部分：
一、作文类型：明确指出作文属于哪种文体（必须是单一明确类型，如"记叙文"、"议论文"、"说明文"等）
二、类型判断依据：列出3-5点支持该判断的具体证据，引用原文关键部分
三、文体特征分析：简要分析该文体在本文中的典型表现和特点

强调: 准确、客观、有依据。你的判断将影响后续评估过程。
"""

    user_prompt = f"""
请根据上述指令，对以下{grade}年级学生作文进行类型识别：

作文要求：
{requirement}

作文标题：{title}

作文内容：
{content}

请生成《作文类型识别报告》。
"""

    return [
        Message(role="system", content=system_prompt),
        Message(role="user", content=user_prompt),
    ]

# 提取作文类型的函数
def extract_essay_type(response: str) -> str:
    """Extract the essay type from an essay-type identification report.

    Matching is attempted in this order:
      1. a heading ("作文类型", then "文体类型", then "类型") followed by a known
         type name, tolerating colons, whitespace, markdown emphasis, quotes
         and brackets in between;
      2. "Essay Type" followed by a known English type name, which is mapped
         to Chinese;
      3. the first entry of a fixed priority list that occurs anywhere in the
         response.

    Type names are matched as whole words, so names that do not end in "文"
    (such as "读后感" or "书信") are recognised after a heading as well.

    Args:
        response: Raw text of the report returned by the LLM.

    Returns:
        The recognised type name. Falls back to '记叙文' when the response is
        empty, contains 'API_ERROR', or holds no recognisable type.
    """
    if not response or 'API_ERROR' in response:
        return '记叙文'  # default type

    # Type names accepted after an explicit heading.
    valid_chn_types = ["记叙文", "议论文", "说明文", "应用文", "读后感", "描写文", "想象文", "书信", "日记", "演讲稿", "辩论稿", "诗歌", "散文", "小说"]
    # Priority list for the free-text fallback.
    common_types = ["记叙文", "议论文", "说明文", "应用文", "描写文", "想象文", "童话故事", "推荐文", "读后感", "推荐类型", "混合型", "书信", "日记"]

    # English type names mapped to Chinese.
    eng_to_chn = {
        "Narrative": "记叙文",
        "Argumentative": "议论文",
        "Expository": "说明文",
        "Practical": "应用文",
        "Reflective": "读后感",
        "Descriptive": "描写文",
        "Imaginative": "想象文",
        "Letter": "书信",
        "Diary": "日记",
        "Poetry": "诗歌",
        "Prose": "散文",
        "Novel": "小说",
    }

    # De-duplicate while keeping order, then try longer names first.
    known_types = sorted(dict.fromkeys(valid_chn_types + common_types), key=len, reverse=True)
    chn_alt = '|'.join(re.escape(name) for name in known_types)
    eng_alt = '|'.join(sorted(eng_to_chn, key=len, reverse=True))

    # Characters allowed between a heading and the type name.
    filler = r'[\s:：*"“”\'‘’「」『』《》【】\[\]()（）]*'

    # Most specific heading first.
    for heading in ("作文类型", "文体类型", "类型"):
        match = re.search(heading + filler + r'(' + chn_alt + r')', response)
        if match:
            return match.group(1)

    match = re.search(r'Essay Type' + filler + r'(' + eng_alt + r')', response, re.IGNORECASE)
    if match:
        return eng_to_chn[match.group(1).capitalize()]

    # No labelled type found: return the first common type mentioned anywhere.
    for essay_type in common_types:
        if essay_type in response:
            return essay_type

    # Nothing recognisable: fall back to the default.
    logger.warning(f"Could not extract essay type from response: {response[:200]}... Defaulting to '记叙文'.")
    return "记叙文"


# --- Prompt Revisions Start ---
def generate_prompt_interpretation_prompt(essay_data: Dict[str, Any], essay_type: str = None, few_shot_examples: List[Dict[str, Any]] = None) -> List[Message]:
    """
    Build the chat messages for the requirement interpreter (Expert 1: Prompt Precision Analyst).

    Reads 'grade' and 'requirement' from `essay_data`. When `essay_type` is
    given, a line stating the identified type is inserted into the system
    prompt. `few_shot_examples` is accepted but not used here.
    """
    requirement = essay_data.get("requirement", "无具体要求")
    grade = essay_data.get("grade", "未知")
    if essay_type:
        essay_type_info = f"已识别的作文类型：{essay_type}"
    else:
        essay_type_info = ""

    system_prompt = f"""
你的角色: 中小学作文题目解析与要求拆解师。专注于客观解构题目文本。

你的任务: 对{grade}年级作文题目进行细致解读，明确核心要求、限制条件、评价重点及潜在考察点，建立可衡量的切题度评估基准。

{essay_type_info}

分析步骤:

核心写作任务识别: 用一句话明确核心指令 (如: 记叙经历、描写景色、阐述观点)。 界定任务类型 (记叙、描写、议论等)。
关键限制条件提取: 以列表形式列出所有明确限制 (文体、内容、对象、时空、人称、字数、表达方式等)。 如无明确限制，需指出 "无显性XX限制"。
隐含要求与评价倾向分析**: 基于{grade}年级水平，分析可能的隐含要求 (情感基调、价值导向、思维能力、选材范围建议)。 需注明是推断。
切题度核心检查点: 提炼3-5个具体的、可操作的检查点 (尽可能接近"是/否"判断)，用于快速判断是否切题及深入程度。 例如：是否围绕[核心任务关键词]？是否包含[关键要素]？是否符合[明确限制]？
输出要求:
生成《作文题目要求精准解读报告》，包含以下四部分，内容对应上述步骤：
一、核心写作任务
二、关键限制条件
三、隐含要求与评价倾向
四、切题度核心检查点

强调: 客观、精确、全面、操作性强。你的输出是后续评估的基础。
"""

    user_prompt = f"""
请根据上述指令，对以下{grade}年级作文题目要求进行精准解读：

作文要求原文:
{requirement}

请生成《作文题目要求精准解读报告》。
"""

    return [
        Message(role="system", content=system_prompt),
        Message(role="user", content=user_prompt),
    ]

def generate_main_idea_summarizer_prompt(essay_data: Dict[str, Any], essay_type: str = None, few_shot_examples: List[Dict[str, Any]] = None) -> List[Message]:
    """
    Build the chat messages for the main-idea summarizer (Expert 2: Essay Essence Extractor).

    Reads 'title', 'content', 'grade' and 'requirement' from `essay_data`.
    When `essay_type` is given, a line stating the identified type is
    inserted into the system prompt. `few_shot_examples` is accepted but not
    used here.
    """
    grade = essay_data.get("grade", "未知")
    requirement = essay_data.get("requirement", "无具体要求")
    title = essay_data.get("title", "无标题")
    content = essay_data.get("content", "无内容")
    if essay_type:
        essay_type_info = f"已识别的作文类型：{essay_type}"
    else:
        essay_type_info = ""

    system_prompt = f"""
你的角色: 中小学核心立意提炼评估师。快速把握文章核心。

你的任务: 阅读{grade}年级作文，精炼概括核心内容和中心思想，并评估标题契合度与主旨清晰度。

{essay_type_info}

分析要点:

核心内容/事件概括: 用不超过 50 字客观陈述主要事件、描写对象或核心观点 (只述事实)。
中心思想/主旨提炼**: 用一句话明确作者最核心的情感、观点或感悟。 若模糊，请指出 "中心思想模糊不清" 或 "未能提炼出明确的中心思想"。
标题与内容契合度评估: 评估标题"{title}"与核心内容/主旨的匹配度。 选其一并简述理由: 高度契合 | 基本契合 | 部分契合 | 不太契合 | 完全无关。
中心思想表达清晰度评估: 评估主旨传达的明确性。 选其一并简述理由: 非常清晰 | 比较清晰 | 有些模糊 | 非常混乱/缺失。
输出要求:
生成《作文主旨与核心内容分析报告》，包含以下四部分，内容对应上述要点：
一、核心内容/事件概括:
二、中心思想/主旨提炼:
三、标题与内容契合度: (评级 + 理由)
四、中心思想表达清晰度: (评级 + 理由)

强调: 客观、精炼、准确，基于文本证据。
"""

    user_prompt = f"""
请根据上述指令，分析以下{grade}年级学生作文的主旨与核心内容：

作文标题： {title}
作文要求原文:
{requirement}
作文内容：
{content}

请生成《作文主旨与核心内容分析报告》。
"""

    return [
        Message(role="system", content=system_prompt),
        Message(role="user", content=user_prompt),
    ]

def generate_content_task_analyzer_prompt(essay_data: Dict[str, Any], prompt_interpretation: str, main_idea_summary: str, essay_type: str = None, few_shot_examples: List[Dict[str, Any]] = None) -> List[Message]:
    """
    Build the chat messages for the content/task alignment checker (Expert 3: Relevance Alignment Verifier).

    Reads 'content' and 'grade' from `essay_data` and embeds the two earlier
    reports (`prompt_interpretation`, `main_idea_summary`) in the user
    message. When `essay_type` is given, a line stating the identified type
    is inserted into the system prompt. `few_shot_examples` is accepted but
    not used here.
    """
    grade = essay_data.get("grade", "未知")
    content = essay_data.get("content", "无内容")
    if essay_type:
        essay_type_info = f"已识别的作文类型：{essay_type}"
    else:
        essay_type_info = ""

    system_prompt = f"""
你的角色: 中小学作文切题度符合性校验师。严谨核对作文是否满足题目要求清单。

你的任务: 基于《作文题目要求精准解读报告》和《作文主旨与核心内容分析报告》，系统评估{grade}年级作文内容与题目要求的符合度，定位符合与偏离之处，给出初步切题度评级。

{essay_type_info}

评估步骤:

核心任务完成度评估: 对照解读报告中的"核心写作任务"，判断作文是否完成。 选其一并详述依据 (引用解读报告任务描述和作文核心内容/主旨对比)：完全符合 | 基本符合 | 部分符合 | 不符合。
限制条件符合度检查: 对照解读报告"关键限制条件"，逐一检查作文。 列表回应，对每项给出判断 (符合/不符合/部分符合/不适用/无法判断) 及简要理由/证据。
隐含要求响应度分析: 对照解读报告"隐含要求与评价倾向"，评估作文响应程度。 选总体评价并说明依据 (结合主旨报告和作文细节)：高度响应 | 基本响应 | 部分响应 | 几乎未响应 | 无明显隐含要求。
切题亮点与偏题点定位: 找出1-2个最符合要求的亮点，引用原文说明。 找出1-2个最明显偏离要求的点，引用原文说明。 若无偏题点，明确说明 "未发现明显偏离"。
切题度初步评级: 综合以上评估，给出初步切题度等级 (优秀/较好/一般/合格/不合格)，并简述主要理由。
输出要求:
生成《作文内容与任务符合性校验报告》，包含以下五部分，内容对应上述步骤：
一、核心任务完成度评估: (评级 + 依据)
二、限制条件符合度检查: (逐条检查结果 + 理由/证据)
三、隐含要求响应度分析: (评级 + 说明)
四、切题亮点与偏题点定位: (亮点原文+说明，偏题点原文+说明，或无偏题说明)
五、切题度初步评级: (评级 + 主要理由)

强调: 严谨对照，基于证据，具体指出，客观判断。
"""

    user_prompt = f"""
请根据上述指令，对以下{grade}年级学生作文进行内容与任务符合性校验：

输入文件1: 《作文题目要求精准解读报告》
{prompt_interpretation}

输入文件2: 《作文主旨与核心内容分析报告》
{main_idea_summary}

作文原文 (供校验核对):
{content}

请生成《作文内容与任务符合性校验报告》。
"""

    return [
        Message(role="system", content=system_prompt),
        Message(role="user", content=user_prompt),
    ]

def generate_material_relevance_checker_prompt(essay_data: Dict[str, Any], prompt_interpretation: str, main_idea_summary: str, essay_type: str = None, few_shot_examples: List[Dict[str, Any]] = None) -> List[Message]:
    """
    Build the chat messages for the material/theme alignment auditor (Expert 4: Material-Theme Alignment Auditor).

    Reads 'content' and 'grade' from `essay_data` and embeds the two earlier
    reports in the user message. The formatted few-shot examples are appended
    to the end of the system prompt. When `essay_type` is given, a line
    stating the identified type is inserted into the system prompt.
    """
    grade = essay_data.get("grade", "未知")
    content = essay_data.get("content", "无内容")
    if essay_type:
        essay_type_info = f"已识别的作文类型：{essay_type}"
    else:
        essay_type_info = ""

    # Few-shot examples rendered as text (empty string when there are none).
    few_shot_prompt = format_few_shot_examples(few_shot_examples)

    system_prompt = f"""
角色: 中小学作文素材主题适配评审员。审视素材是否指向主题。

任务: 分析{grade}年级作文素材（事例、描写、细节等），评估其与题目要求（基于解读报告）和文章主旨（基于主旨报告）的相关性与有效性。

{essay_type_info}

评估要点:

素材与核心任务/主旨的关联度评估: 分析主要素材是否直接服务于"核心写作任务"和"中心思想"。 选总体评价并说明理由（主要素材如何支撑或未能支撑）：高度相关 | 基本相关 | 部分相关 | 关联度低 | 基本无关。
素材的典型性与适切性评估: 评估素材对表达主旨/完成任务是否典型、具体、恰当（考虑{grade}年级）。 给出总体评价，并举例说明（正面或反面）。
最相关素材示例: 找出1个最能紧扣主题、服务核心任务/主旨的素材/段落。 简述素材，引用原文(约20-50字)，说明其相关性。
最不相关/低效素材示例: 找出1个关联最弱、作用最小或产生干扰的素材/部分 (若存在)。 简述内容，引用原文(约20-50字)，说明其为何不相关/低效。 若无，明确说明 "未发现明显不相关或低效素材"。
素材服务主旨的总体有效性评级: 综合评定素材支撑主题的整体有效性 (优秀/较好/一般/合格/不合格)，并简述主要理由。
输出要求:
生成《作文素材主题契合度审核报告》，包含以下五部分，内容对应上述要点：
一、素材与核心任务/主旨的关联度: (评级 + 理由)
二、素材的典型性与适切性: (评价 + 例子)
三、最相关素材示例: (描述 + 原文 + 说明)
四、最不相关/低效素材示例: (描述 + 原文 + 说明，或无此情况说明)
五、素材服务主旨的总体有效性评级: (评级 + 主要理由)

强调: 聚焦素材，双重对标（题目要求 & 文章主旨），具体分析，区分主次。
{few_shot_prompt}
"""

    user_prompt = f"""
请根据上述指令，审核以下{grade}年级学生作文中素材的主题契合度：

输入文件1: 《作文题目要求精准解读报告》
{prompt_interpretation}

输入文件2: 《作文主旨与核心内容分析报告》
{main_idea_summary}

作文原文 (供审核):
{content}

请生成《作文素材主题契合度审核报告》。
"""

    return [
        Message(role="system", content=system_prompt),
        Message(role="user", content=user_prompt),
    ]

def generate_comprehensive_reviewer_prompt(essay_data: Dict[str, Any], analysis_reports: Dict[str, str], essay_type: str = None, few_shot_examples: List[Dict[str, Any]] = None) -> List[Message]:
    """
    Build the chat messages for the final reviewer (Expert 5: Final Relevance Adjudicator).

    Reads 'grade', 'requirement', 'title' and 'content' from `essay_data`.
    The reports in `analysis_reports` are concatenated under fixed headings
    (a missing report is replaced by a placeholder) and embedded in the user
    message together with the original essay. The formatted few-shot examples
    are appended to the end of the system prompt.
    """
    grade = essay_data.get("grade", "未知")
    requirement = essay_data.get("requirement", "无具体要求")
    title = essay_data.get("title", "无标题")
    content = essay_data.get("content", "无内容")  # original essay, supplied for the final cross-check
    if essay_type:
        essay_type_info = f"已识别的作文类型：{essay_type}"
    else:
        essay_type_info = ""

    # Few-shot examples rendered as text (empty string when there are none).
    few_shot_prompt = format_few_shot_examples(few_shot_examples)

    # (heading, key in analysis_reports) for each report, in output order.
    report_sections = [
        ("--- 《作文题目要求精准解读报告》(Prompt Precision Analysis Report) ---\n", "作文要求精准解读报告"),
        ("--- 《作文主旨与核心内容分析报告》(Essay Essence Analysis Report) ---\n", "作文主旨与核心内容分析报告"),
        ("--- 《作文内容与任务符合性校验报告》(Content-Task Alignment Verification Report) ---\n", "作文内容与任务符合性校验报告"),
        ("--- 《作文素材主题契合度审核报告》(Material-Theme Alignment Audit Report) ---\n", "作文素材主题契合度审核报告"),
    ]
    # The type-identification report is optional and goes last when present.
    if "作文类型识别报告" in analysis_reports:
        report_sections.append(("--- 《作文类型识别报告》(Essay Type Classification Report) ---\n", "作文类型识别报告"))

    reports_text = "".join(
        heading + analysis_reports.get(report_key, "报告缺失") + "\n\n"
        for heading, report_key in report_sections
    )

    system_prompt = f"""
你的角色是: **高度权威且极其严谨的**中小学作文评审终审裁定组。负责最终等级判定，**严守标准，杜绝评级虚高**。

你的任务: 仔细审阅下方提供的作文信息、题目要求、以及四份（或五份，包含类型识别）专家分析报告。基于对所有信息的**综合与批判性理解**，对照官方标准，对该{grade}年级学生作文的切题度做出**精准、公正**的最终等级评定，并提供清晰、有力的定级依据。

{essay_type_info}

官方切题度等级标准 (**必须严格遵守，作为唯一评判依据**):
优秀(5分): 90%完成题目核心任务要求，作文**高度契合**写作要求的所有方面。立意精准深刻，中心思想**极其**明确、突出。选材**完全**服务于中心，恰到好处，与主题高度统一、相得益彰，无任何偏离性内容，素材均与主题形成强关联。
较好(4分): 完成主要写作任务(≥80%)，作文**充分满足**写作要求。立意准确，中心思想**很**明确、清晰。选材**紧密**围绕中心，较为恰当，与主题结合良好，主要素材有效支撑主题。
一般(3分): 基本完成任务(≥70%) ，作文**基本符合**写作要求。立意**大致准确**，中心思想**尚属**清晰，但可能不够突出或集中。选材**基本**能服务于中心，但可能存在部分关联性不强或不够典型的材料。整体表现**可以接受**。
合格(2分): 任务完成度≥60%，作文**最低限度**满足了写作的基本要求，**未完全脱离**主题范围。但立意**不够精准**，中心思想**不明确**、**不突出**，需要费力寻找。选材与主题的**关联度低**，**不能有效支撑**中心。**存在明显偏题、跑题或对要求理解不到位的情况，中心思想需反复推敲才能关联题目**。
不合格(1分): 核心任务未完成，作文**严重偏离**写作要求。立意**错误或极其模糊**。中心思想**完全不清晰**，无法把握，中心思想与题目要求相悖。选材与主题**毫不相干**或**完全相悖**。**完全未理解题目要求，或存在根本性跑题**。

核心裁决原则:
1.  **任务完成度优先**: 文章是否完成了题目的核心任务？是否遵守了所有明确的限制条件（如文体、字数、特定内容要求等）？**任何核心任务未完成或限制条件未遵守，原则上不能评为“一般”及以上。**
2.  **中心思想是关键**: 中心思想是否明确？是否紧扣题目要求？**中心思想模糊、偏离是判定为“合格”或“不合格”的关键指标。**
3.  **内容与材料的相关性**: 选择的内容和材料是否**直接且有效**地支撑中心思想？**材料空泛、堆砌、与主题关联弱是拉低评分的重要因素。**
4.  **严格把控界限**: **特别注意“合格”与“不合格”、“一般”与“合格”的界限。** 对于可能处于边界的作文，**必须从严审视**其是否达到了更高一级标准的所有基本要求。**绝不允许因为个别语句尚可或结构完整等非切题因素而将“合格”作文提升至“一般”，或将“不合格”作文评为“合格”。**
5.  **综合判断，拒绝“平均主义”**: 依据作文整体表现与标准的符合度进行判断，而不是简单地对专家报告进行平均。要敢于推翻明显不合理的初步意见。

评定步骤:
1.  **信息整合与审视**: 快速把握四份（或五份）报告的关键信息（题目核心任务/限制，文章内容/主旨/清晰度，内容符合度/硬伤，素材有效性），**并对其结论的合理性进行独立判断**。
2.  **对照标准严格裁决**:
    *   **首先判断是否达到“合格”标准**: 检查是否存在严重偏题、中心思想完全模糊、材料毫不相关等“不合格”的硬伤。若存在，直接判定为“不合格”。
    *   **若非“不合格”，再判断是否满足“合格”的所有描述**: 检查中心思想是否不明确、材料关联度是否低等。若符合，判定为“合格”。
    *   **只有确认满足“合格”标准后，才继续向上评估**：“一般”、“较好”、“优秀”。**逐级比对，选择最贴切的等级描述。**
3.  **阐述定级依据**:
    *   清晰说明为何评定为该等级，**必须**结合专家报告的关键发现、**作文原文的具体表述 (必要时可引用)** 以及官方标准条文。
    *   提供**至少2-3个最关键的、决定性的证据**（优点或缺点）来支撑你的定级。
    *   **必须明确指出是什么核心要素（或核心缺陷）导致了这个最终评级。** 如果评为“一般”或以下，需说明**主要扣分点**在哪里；如果评为“较好”或“优秀”，需说明**最突出的亮点**是什么，且确保其符合该等级描述。

输出要求:
生成《切题度最终裁决报告》，包含以下三部分：
一、最终分类: (明确选择：优秀/较好/一般/合格/不合格)
二、定级理由: (详细说明，引用报告发现、原文内容（可选）和官方标准，**论证充分，逻辑清晰**)
三、决定性因素: (明确指出**1个**最关键的因素，解释其为何具有决定性作用)
强调: **判定的精准性与权威性**、**标准的刚性执行**、**论证的严密性与说服力**。**聚焦切题度本身，不受其他写作维度（如文采、错别字、部分结构瑕疵）的直接干扰，除非它们严重影响了主题的表达和理解。**
{few_shot_prompt}
"""

    user_prompt = f"""
请根据上述指令，对以下{grade}年级学生作文的切题度进行最终裁决：

作文题目要求原文: {requirement}
作文标题: {title}
作文原文 (供最终核对):
{content}

输入分析报告汇总:
{reports_text}

请生成《切题度最终裁决报告》。
"""

    return [
        Message(role="system", content=system_prompt),
        Message(role="user", content=user_prompt),
    ]

# !!! 注意：这个函数保持不变，不添加 Few-Shot 示例 !!!
def generate_final_classification_prompt(essay_data: Dict[str, Any], comprehensive_review: str) -> List[Message]:
    """
    Build the chat messages for the final classification extractor (Final Classification Extractor).

    Role given to the model: review-result recorder.
    Task given to the model: pull the final classification word out of the
    final adjudication report passed in as `comprehensive_review`.

    Reads 'grade', 'requirement' and 'title' from `essay_data`. No few-shot
    examples are used by this prompt.
    """
    title = essay_data.get("title", "无标题")
    requirement = essay_data.get("requirement", "无具体要求")
    grade = essay_data.get("grade", "未知")

    # Fixed instructions; this string contains no placeholders.
    system_prompt = """
角色: 中小学作文评审结果记录员。

任务: 从下方《切题度最终裁决报告》中，准确提取 "一、最终分类 (Final Classification):" 后面的分类结果词。

结果必须是以下之一： 优秀 | 较好 | 一般 | 合格 | 不合格

输出要求:
仅输出上述五个词语中的一个。
无任何其他文字、解释、标点、空格或换行。
"""

    user_prompt = f"""
请从以下《切题度最终裁决报告》中提取最终的分类结果：

学生信息： {grade}年级
作文题目要求： {requirement}
作文标题： {title}

《切题度最终裁决报告》：
{comprehensive_review}

请严格按要求，只输出最终分类词（例如：较好）。
"""

    return [
        Message(role="system", content=system_prompt),
        Message(role="user", content=user_prompt),
    ]
# --- Prompt Revisions End ---

# 改进的BERT模型部分 (保持不变)
class AdvancedEssayDataset(Dataset):
    """Dataset that encodes each essay's requirement, title and content separately.

    Every item holds the three per-field encodings plus their concatenation,
    which is padded or truncated to exactly `max_length` tokens. Token budgets
    are computed before tokenizing, so each field is encoded once per item.

    Args:
        essays: Indexable collection of essay dicts. The keys read are 'id',
            'requirement', 'title', 'content' and, optionally, 'classification'.
        tokenizer: Callable invoked as
            tokenizer(text, truncation=True, max_length=N, padding='max_length',
            return_tensors='pt'). It is expected to return 'input_ids' and
            'attention_mask' tensors of shape (1, N) — assumption based on how
            the results are indexed here.
        max_length: Length of the concatenated sequence (default 512).
    """

    # Default per-field token budgets and the minimums applied when shrinking.
    REQ_LEN = 128
    TITLE_LEN = 32
    CONTENT_LEN = 384
    MIN_CONTENT_LEN = 64
    MIN_REQ_LEN = 32

    def __init__(self, essays, tokenizer, max_length=512):
        self.essays = essays
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.essays)

    def _segment_lengths(self):
        """Return (req_len, title_len, content_len) token budgets for `max_length`."""
        req_len, title_len, content_len = self.REQ_LEN, self.TITLE_LEN, self.CONTENT_LEN
        if req_len + title_len + content_len > self.max_length:
            # Shrink the content first ...
            content_len = self.max_length - req_len - title_len
            if content_len < self.MIN_CONTENT_LEN:
                # ... and, once it hits its floor, shrink the requirement as well.
                content_len = self.MIN_CONTENT_LEN
                req_len = max(self.MIN_REQ_LEN, self.max_length - content_len - title_len)
        return req_len, title_len, content_len

    def _encode(self, text, length):
        """Tokenize `text`, truncated and padded to exactly `length` tokens."""
        return self.tokenizer(
            text,
            truncation=True,
            max_length=length,
            padding='max_length',
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        essay = self.essays[idx]
        req_len, title_len, content_len = self._segment_lengths()

        # Encode the three text fields separately, each with its own budget.
        req_encoding = self._encode(essay.get('requirement', ''), req_len)
        title_encoding = self._encode(essay.get('title', ''), title_len)
        content_encoding = self._encode(essay.get('content', ''), content_len)

        encodings = (req_encoding, title_encoding, content_encoding)
        segment_ids = [enc['input_ids'].squeeze(0) for enc in encodings]
        segment_masks = [enc['attention_mask'].squeeze(0) for enc in encodings]

        input_ids = torch.cat(segment_ids, dim=0)
        attention_mask = torch.cat(segment_masks, dim=0)

        padding_needed = self.max_length - input_ids.shape[0]
        if padding_needed < 0:
            # Only reachable when max_length is below the sum of the minimum budgets.
            logger.warning(f"Essay ID {essay.get('id', 'N/A')}: Calculated padding needed is negative ({padding_needed}). Review truncation logic.")
            # Content is the last segment, so cutting the tail trims content first.
            input_ids = input_ids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]
        elif padding_needed > 0:
            # Use the tokenizer's pad id when it has one; otherwise 0.
            pad_id = getattr(self.tokenizer, 'pad_token_id', None)
            if pad_id is None:
                pad_id = 0
            input_ids = torch.cat([input_ids, torch.full((padding_needed,), pad_id, dtype=torch.long)], dim=0)
            attention_mask = torch.cat([attention_mask, torch.zeros(padding_needed, dtype=torch.long)], dim=0)

        item = {
            'id': str(essay.get('id', f'index_{idx}')),
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            # Separate encodings, kept for use inside the model.
            'req_input_ids': segment_ids[0],
            'req_attention_mask': segment_masks[0],
            'title_input_ids': segment_ids[1],
            'title_attention_mask': segment_masks[1],
            'content_input_ids': segment_ids[2],
            'content_attention_mask': segment_masks[2],
        }

        # The 'classification' field is the ground-truth label (training and evaluation).
        true_label_name = essay.get('classification')
        if true_label_name and true_label_name in CLASSIFICATION_LABELS:
            label = CLASSIFICATION_LABELS[true_label_name] - 1  # the model expects labels 0-4
            item['labels'] = torch.tensor(label, dtype=torch.long)

        return item


# 高级注意力机制
class MultiHeadSelfAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input ``x`` is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``hidden_size // num_heads``, attended, merged
    back and passed through a final linear layer.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        assert hidden_size == self.head_dim * num_heads, "Hidden size must be divisible by num_heads"

        self.query = torch.nn.Linear(hidden_size, hidden_size)
        self.key = torch.nn.Linear(hidden_size, hidden_size)
        self.value = torch.nn.Linear(hidden_size, hidden_size)
        self.out = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(B, S, H)`` and return a ``(B, S, H)`` tensor.

        ``mask`` is optional and may be shaped ``(B, S)``, ``(B, 1, S)`` or
        ``(B, 1, 1, S)``; positions where it equals 0 are excluded as keys.
        Any other mask shape raises ``ValueError``.
        """
        bsz, steps, _ = x.size()
        head_shape = (bsz, steps, self.num_heads, self.head_dim)

        # Project, then move the head axis forward: (B, S, H) -> (B, nh, S, hd).
        q = self.query(x).view(*head_shape).transpose(1, 2)
        k = self.key(x).view(*head_shape).transpose(1, 2)
        v = self.value(x).view(*head_shape).transpose(1, 2)

        # Pairwise query/key affinities, scaled by sqrt(head_dim): (B, nh, S, S).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            rank = mask.dim()
            # Bring every accepted layout to (B, 1, 1, S) so it broadcasts
            # over heads and query positions.
            if rank == 2:
                key_mask = mask[:, None, None, :]
            elif rank == 3 and mask.shape[1] == 1:
                key_mask = mask[:, :, None, :]
            elif rank == 4 and mask.shape[1] == 1 and mask.shape[2] == 1:
                key_mask = mask
            else:
                raise ValueError(f"Unexpected attention mask shape: {mask.shape}")

            # A large negative score drives the softmax weight of masked keys to ~0.
            scores = scores.masked_fill(key_mask == 0, -1e9)

        weights = torch.softmax(scores, dim=-1)  # (B, nh, S, S)
        per_head = torch.matmul(weights, v)  # (B, nh, S, hd)

        # Merge the heads back into the hidden dimension: (B, S, H).
        merged = per_head.transpose(1, 2).contiguous().view(bsz, steps, self.num_heads * self.head_dim)
        return self.out(merged)

# 特征融合模块
class FeatureFusionModule(torch.nn.Module):
    """Fuse two feature vectors with a two-layer MLP plus a residual on the first.

    The inputs are concatenated on the last axis, passed through
    Linear -> GELU -> Dropout(0.1) -> Linear, added to ``x1`` and layer-normalised.

    Args:
        hidden_size: Size of the last dimension of each input and of the output.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)
        self.dropout = torch.nn.Dropout(0.1)
        self.gelu = torch.nn.GELU()

    def forward(self, x1, x2):
        """Return the fused representation of ``x1`` and ``x2``.

        A 3-D input is reduced to its first position along axis 1 (the CLS slot
        for BERT sequence outputs) before fusion; other ranks pass through untouched.
        """
        # Sequence-shaped inputs are collapsed to their first token.
        x1, x2 = (t[:, 0] if t.dim() == 3 else t for t in (x1, x2))

        joined = torch.cat((x1, x2), dim=-1)  # expected [B, 2*H]

        # Two-layer MLP over the concatenated pair.
        hidden = self.dropout(self.gelu(self.linear1(joined)))
        projected = self.linear2(hidden)

        # Residual connection to the first input, then layer normalisation.
        return self.layer_norm(projected + x1)

# Enhanced BERT model architecture
class EnhancedBertEssayModel(torch.nn.Module):
    """BERT-based essay classifier with two alternative feature paths.

    * Separate-input path (used when all six ``req_*``/``title_*``/``content_*``
      tensors are passed to ``forward``): each part is encoded by the shared BERT
      encoder, its pooled output goes through a per-part linear head, and the
      three vectors are merged by ``FeatureFusionModule`` stages.
    * Concatenated-input path (otherwise): one BERT pass, then multi-head
      self-attention, a 2-layer bidirectional LSTM and max pooling over the
      sequence axis.

    Either path yields a ``[B, hidden_size]`` vector which is fed to
    ``feature_network`` and then ``classifier``.

    Args:
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
        num_classes: Number of output logits.
    """

    def __init__(self, bert_model_name="bert-base-chinese", num_classes=len(CLASSIFICATION_LABELS)):
        super(EnhancedBertEssayModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.hidden_size = self.bert.config.hidden_size

        # Multi-head self-attention applied on top of BERT's sequence output.
        self.self_attention = MultiHeadSelfAttention(self.hidden_size, num_heads=8)

        # Per-part projection heads for the pooled outputs of the separate-input path.
        self.req_encoder_head = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.title_encoder_head = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.content_encoder_head = torch.nn.Linear(self.hidden_size, self.hidden_size)

        # Fusion stages: (requirement, title), (title, content), then both results.
        self.fusion_req_title = FeatureFusionModule(self.hidden_size)
        self.fusion_title_content = FeatureFusionModule(self.hidden_size)
        self.fusion_final = FeatureFusionModule(self.hidden_size)

        # Bidirectional LSTM over the attended sequence; each direction has
        # hidden_size // 2 units so the concatenated output is hidden_size wide.
        lstm_layers = 2
        self.lstm = torch.nn.LSTM(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size // 2,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout is only meaningful with more than one layer.
            dropout=0.2 if lstm_layers > 1 else 0
        )

        # Feature extractor applied to the pooled / fused [B, H] vector.
        self.feature_network = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.LayerNorm(self.hidden_size),
            torch.nn.GELU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(self.hidden_size, self.hidden_size // 2),
            torch.nn.LayerNorm(self.hidden_size // 2),
            torch.nn.GELU(),
            torch.nn.Dropout(0.1)
        )

        # Classification head on top of feature_network's [B, H/2] output.
        self.classifier = torch.nn.Linear(self.hidden_size // 2, num_classes)

        # Initialise only the newly added layers (see _init_weights).
        self._init_weights()

    def _init_weights(self):
        """Initialise the task-specific layers, leaving the pretrained BERT untouched.

        ``self.modules()`` walks the whole module tree, including every Linear and
        LayerNorm inside ``self.bert``. Re-initialising those would throw away the
        weights just loaded by ``from_pretrained``, so modules owned by the
        encoder are skipped.
        """
        pretrained_ids = {id(m) for m in self.bert.modules()}

        for module in self.modules():
            if id(module) in pretrained_ids:
                continue

            if isinstance(module, torch.nn.Linear):
                torch.nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, torch.nn.LSTM):
                for name, param in module.named_parameters():
                    if 'weight_ih' in name:
                        torch.nn.init.xavier_uniform_(param.data)
                    elif 'weight_hh' in name:
                        torch.nn.init.orthogonal_(param.data)
                    elif 'bias' in name:
                        param.data.fill_(0)
                        # The second quarter of each bias vector is set to 1
                        # (the forget-gate slice in PyTorch's gate layout).
                        # This runs for both bias_ih and bias_hh.
                        n = param.size(0)
                        start, end = n // 4, n // 2
                        param.data[start:end].fill_(1.)
            elif isinstance(module, torch.nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)

    def forward(self, input_ids, attention_mask, req_input_ids=None, req_attention_mask=None,
                title_input_ids=None, title_attention_mask=None,
                content_input_ids=None, content_attention_mask=None):
        """Return classification logits of shape ``[B, num_classes]``.

        Args:
            input_ids, attention_mask: Concatenated requirement/title/content
                encoding; only consumed by the concatenated-input path.
            req_*/title_*/content_*: Optional per-part encodings. When all six
                are provided the separate-input path is used instead.

        Raises:
            ValueError: If the feature vector reaching ``feature_network`` is not
                ``[B, hidden_size]``, or if the self-attention layer rejects the
                attention-mask shape.
        """
        # The separate-input path needs every one of the six per-part tensors.
        use_separate_inputs = (req_input_ids is not None and req_attention_mask is not None and
                               title_input_ids is not None and title_attention_mask is not None and
                               content_input_ids is not None and content_attention_mask is not None)

        if use_separate_inputs:
            # Encode each part with the shared encoder and project its pooled output.
            req_outputs = self.bert(input_ids=req_input_ids, attention_mask=req_attention_mask, return_dict=True)
            req_features = self.req_encoder_head(req_outputs.pooler_output)  # [B, H]

            title_outputs = self.bert(input_ids=title_input_ids, attention_mask=title_attention_mask, return_dict=True)
            title_features = self.title_encoder_head(title_outputs.pooler_output)  # [B, H]

            content_outputs = self.bert(input_ids=content_input_ids, attention_mask=content_attention_mask, return_dict=True)
            content_features = self.content_encoder_head(content_outputs.pooler_output)  # [B, H]

            # Pairwise fusion, then fusion of the two pair representations.
            fused_req_title = self.fusion_req_title(req_features, title_features)
            fused_title_content = self.fusion_title_content(title_features, content_features)
            combined_features = self.fusion_final(fused_req_title, fused_title_content)  # [B, H]
            final_features_input = combined_features  # [B, H]

        else:
            # Single pass over the concatenated input.
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            sequence_output = outputs.last_hidden_state  # [B, S, H]

            # Hand the self-attention layer a [B, S] mask where possible.
            if attention_mask.dim() == 3 and attention_mask.shape[1] == 1:
                attention_mask_sa = attention_mask.squeeze(1)
            elif attention_mask.dim() == 2:
                attention_mask_sa = attention_mask
            else:
                # Pass the mask through unchanged; the attention layer validates
                # the shape itself and raises ValueError if it cannot use it.
                logger.warning(f"Unexpected attention mask shape for self-attention: {attention_mask.shape}. Using as is.")
                attention_mask_sa = attention_mask

            attended_output = self.self_attention(sequence_output, attention_mask_sa)  # [B, S, H]

            # The padded batch is fed to the BiLSTM directly (no packing).
            lstm_output, _ = self.lstm(attended_output)  # [B, S, H]

            # Max pooling over the sequence axis.
            # NOTE(review): padded time steps take part in this max — confirm
            # that is intended before changing it, since the training loop
            # mirrors this pooling.
            pooled_output = torch.max(lstm_output, dim=1)[0]  # [B, H]

            final_features_input = pooled_output  # [B, H]

        # Safeguard: collapse an unexpected 3-D tensor to [B, H].
        if final_features_input.dim() == 3:
            if final_features_input.shape[1] == 1:
                final_features_input = final_features_input.squeeze(1)  # [B, 1, H] -> [B, H]
            else:
                # Several sequence steps left: reduce them with a max.
                final_features_input = torch.max(final_features_input, dim=1)[0]  # [B, H]

        # feature_network requires a [B, H] input.
        if final_features_input.dim() != 2 or final_features_input.shape[-1] != self.hidden_size:
            raise ValueError(f"Input to feature_network has unexpected shape: {final_features_input.shape}. Expected [B, {self.hidden_size}]")
        enhanced_features = self.feature_network(final_features_input)  # [B, H/2]

        # Classification head.
        logits = self.classifier(enhanced_features)  # [B, num_classes]

        return logits



# 无监督学习的对比损失函数
class ContrastiveLoss(torch.nn.Module):
    """InfoNCE-style loss computed within a single batch of feature vectors.

    Rows are L2-normalised, the ``[B, B]`` cosine-similarity matrix is divided by
    ``temperature`` and used as logits, and row ``i`` is trained to select
    column ``i`` via cross-entropy. Because each row is compared with itself,
    the diagonal is 1 for non-zero rows, so the loss falls as the similarity
    between different samples falls.

    Args:
        temperature: Divisor applied to the similarities before the softmax.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)

    def forward(self, features):
        """Return the scalar loss for ``features`` of shape ``[B, D]``.

        With fewer than two samples there is nothing to contrast, and a constant
        zero that still reports ``requires_grad=True`` is returned.
        """
        n_samples = features.size(0)
        if n_samples <= 1:
            return torch.tensor(0.0, device=features.device, requires_grad=True)

        # Unit-length rows turn the dot products below into cosine similarities.
        unit = torch.nn.functional.normalize(features, p=2, dim=1)
        logits = torch.matmul(unit, unit.t()) / self.temperature  # [B, B]

        # Sample i's target "class" is column i: the diagonal holds the
        # positives, every other column acts as a negative.
        targets = torch.arange(n_samples, device=unit.device)  # [B]

        return torch.nn.functional.cross_entropy(logits, targets)


# 改进的无监督训练函数
def train_bert_model_unsupervised(data, epochs=5, batch_size=16, learning_rate=2e-5):
    """Train an ``EnhancedBertEssayModel`` with the in-batch contrastive loss.

    No labels are used. Only the concatenated-input path of the model is
    exercised (BERT -> self-attention -> BiLSTM -> max pooling ->
    ``feature_network``); the output of ``feature_network`` is what the
    contrastive loss sees. The trained weights are written to
    ``bert_model_path``; a failure to save is logged, not raised.

    Args:
        data: Essay records accepted by ``AdvancedEssayDataset``.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size. With a value above 1 the last incomplete
            batch is dropped, because the loss needs at least two samples.
        learning_rate: AdamW learning rate.

    Returns:
        ``(model, tokenizer)`` on success, ``(None, None)`` when there is no
        usable data, or ``(None, tokenizer)`` when the data loader yields no batch.
    """
    if not data:
        logger.error("No data provided for BERT model.")
        return None, None

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

    dataset = AdvancedEssayDataset(data, tokenizer)
    if not dataset:
        logger.error("AdvancedEssayDataset created no valid items.")
        return None, None

    # The contrastive loss needs >= 2 samples, so drop a trailing partial batch.
    drop_last_batch = batch_size > 1
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last_batch,
        # Collate only the tensors this loop consumes. Default collation would
        # raise KeyError when just some items of a batch carry the optional
        # 'labels' key, and would also batch six per-part tensors never read here.
        collate_fn=_collate_concatenated_inputs,
    )
    if len(dataloader) == 0:
        logger.error(f"DataLoader is empty. Check batch size ({batch_size}) and dataset size ({len(dataset)}).")
        return None, tokenizer  # The tokenizer is loaded, so hand it back.

    model = EnhancedBertEssayModel()
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # Linear warm-up over the first 10% of steps, then linear decay.
    total_steps = len(dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    contrastive_loss_fn = ContrastiveLoss(temperature=0.07)

    logger.info("Starting unsupervised BERT model training...")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        batches_processed = 0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for batch in progress_bar:
            # Safety net in case an undersized batch slips through.
            if batch['input_ids'].size(0) < 2 and batch_size > 1:
                logger.debug(f"Skipping batch with size {batch['input_ids'].size(0)} (< 2) for contrastive loss.")
                continue

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            optimizer.zero_grad()

            # One feature vector per essay, following the model's
            # concatenated-input path up to (and including) feature_network.
            outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            sequence_output = outputs.last_hidden_state

            if attention_mask.dim() == 3 and attention_mask.shape[1] == 1:
                attention_mask_att = attention_mask.squeeze(1)
            else:
                attention_mask_att = attention_mask

            try:
                attended_output = model.self_attention(sequence_output, attention_mask_att)
                lstm_output, _ = model.lstm(attended_output)
            except ValueError as e:
                logger.error(f"Error during self-attention or LSTM in training: {e}")
                logger.error(f"Sequence output shape: {sequence_output.shape}, attention mask shape: {attention_mask_att.shape}")
                continue  # Skip this batch.

            features_before_mlp = torch.max(lstm_output, dim=1)[0]  # max pooling over the sequence
            contrastive_features = model.feature_network(features_before_mlp)  # [B, H/2]

            loss = contrastive_loss_fn(contrastive_features)

            # Only step when the loss is attached to the model graph. The loss
            # module's constant zero for a single-sample batch has
            # requires_grad=True but no grad_fn, so testing requires_grad would
            # never skip it.
            if loss.grad_fn is not None:
                loss.backward()

                # Clip gradients to guard against exploding updates.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()
                scheduler.step()

            total_loss += loss.item()
            batches_processed += 1

            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        if batches_processed > 0:
            avg_train_loss = total_loss / batches_processed
            logger.info(f"Epoch {epoch+1} - Average training loss: {avg_train_loss:.4f}")
        else:
            logger.warning(f"Epoch {epoch+1} - No batches processed.")

    # Persist the weights; a failure here is logged and the model is still returned.
    try:
        save_dir = os.path.dirname(bert_model_path)
        if save_dir:  # os.makedirs('') raises, so skip it for a bare filename.
            os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), bert_model_path)
        logger.info(f"Unsupervised BERT Model saved to {bert_model_path}")
    except Exception as e:
        logger.error(f"Failed to save BERT model: {e}")

    return model, tokenizer


def _collate_concatenated_inputs(items):
    """Stack the concatenated-input tensors of dataset items into one batch dict."""
    return {
        'input_ids': torch.stack([item['input_ids'] for item in items]),
        'attention_mask': torch.stack([item['attention_mask'] for item in items]),
    }


# 改进的预测函数
def predict_with_bert(model, tokenizer, essay_data):
    """Classify a single essay with the enhanced BERT model.

    The essay is encoded through ``AdvancedEssayDataset``, run through ``model``
    in eval mode without gradients, and the arg-max class index (0-based) is
    mapped to a label through ``SCORE_TO_LABEL`` (1-based scores).

    Returns:
        The predicted label string, or ``None`` when the model or tokenizer is
        missing, the dataset is empty, or the forward pass raises.

    NOTE(review): every per-part tensor is forwarded, so the model takes its
    separate-input branch. The unsupervised training loop in this file only
    drives the concatenated-input layers — confirm this is the intended path.
    """
    if model is None or tokenizer is None:
        logger.warning("BERT model or tokenizer not available for prediction.")
        return None

    dataset = AdvancedEssayDataset([essay_data], tokenizer)
    if len(dataset) == 0:
        logger.warning("Empty dataset for BERT prediction.")
        return None

    # A single sample is used directly; no DataLoader is needed.
    item = dataset[0]

    # The two concatenated tensors are mandatory; add a batch axis and move them.
    model_inputs = {
        'input_ids': item['input_ids'].unsqueeze(0).to(device),
        'attention_mask': item['attention_mask'].unsqueeze(0).to(device),
    }
    # The per-part tensors are optional and become None when absent.
    optional_keys = (
        'req_input_ids', 'req_attention_mask',
        'title_input_ids', 'title_attention_mask',
        'content_input_ids', 'content_attention_mask',
    )
    for key in optional_keys:
        tensor = item.get(key)
        model_inputs[key] = None if tensor is None else tensor.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        try:
            logits = model(**model_inputs)
        except Exception as e:
            logger.error(f"Error during BERT model forward pass for ID {essay_data.get('id')}: {e}", exc_info=True)
            return None

        # Class probabilities, then the most likely class and its probability.
        probabilities = torch.softmax(logits, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

        # Class index 0-4 corresponds to score 1-5.
        predicted_score = predicted.item() + 1
        classification = SCORE_TO_LABEL.get(predicted_score, "一般")

        logger.debug(f"BERT prediction: {classification} (score: {predicted_score}) with confidence: {confidence.item():.4f}")

        # Full distribution, for debugging.
        all_probs = probabilities[0].cpu().numpy()
        prob_str = ", ".join(
            f"{SCORE_TO_LABEL.get(score, 'N/A')}: {prob:.4f}"
            for score, prob in enumerate(all_probs, start=1)
        )
        logger.debug(f"Class probabilities: {prob_str}")

        return classification

# 改进的模型加载/训练函数
def load_or_train_bert_model_unsupervised(data):
    """Load the saved enhanced BERT model, or train a new one without labels.

    If ``bert_model_path`` exists, its state dict is loaded into a fresh
    ``EnhancedBertEssayModel``. When loading fails, the file is moved aside
    (not deleted) and training is attempted instead. Training requires at least
    five records that have non-empty ``content`` and non-``None``
    ``requirement`` and ``title``.

    Args:
        data: Iterable of essay dicts; ``None`` is treated as empty.

    Returns:
        ``(model, tokenizer)``; ``model`` is ``None`` when no model could be
        loaded or trained.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

    # Build a model for loading only when there is a checkpoint to load.
    # Training creates its own instance, so an eager build would construct the
    # encoder twice.
    if os.path.exists(bert_model_path):
        model = EnhancedBertEssayModel()
        try:
            logger.info(f"Loading existing BERT model from {bert_model_path}")
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted location. Where the installed torch
            # supports it, consider passing weights_only=True.
            model.load_state_dict(torch.load(bert_model_path, map_location=device))
            model.to(device)
            logger.info("Enhanced BERT model loaded successfully.")
            return model, tokenizer
        except Exception as e:
            logger.error(f"Failed to load existing BERT model: {e}. Will attempt to train a new one.")
            del model  # Release the half-initialised model before training.
            # Move the file out of the way instead of deleting it: the failure
            # may be transient, and a trained checkpoint is expensive to rebuild.
            backup_path = f"{bert_model_path}.corrupted_{int(time.time())}"
            try:
                os.replace(bert_model_path, backup_path)
                logger.info(f"Moved unreadable model file to {backup_path}")
            except OSError as move_error:
                logger.error(f"Error moving model file: {move_error}")

    logger.info("Existing BERT model not found or failed to load. Preparing to train a new model using unsupervised learning.")

    # Keep only the records that carry the fields AdvancedEssayDataset needs.
    valid_data = [
        d for d in (data or [])
        if isinstance(d, dict)
        and d.get('content')
        and d.get('requirement') is not None
        and d.get('title') is not None
    ]

    if len(valid_data) < 5:  # Too few samples to train on.
        logger.warning(f"Insufficient valid data ({len(valid_data)} found, need at least 5 with req/title/content) to train a new BERT model. BERT auxiliary evaluation will be skipped.")
        return None, tokenizer  # The tokenizer is still usable.

    trained_model, trained_tokenizer = train_bert_model_unsupervised(
        valid_data,
        epochs=3,
        batch_size=8,
        learning_rate=3e-5
    )

    if trained_model is None:
        logger.error("BERT model training failed.")
        return None, tokenizer

    trained_model.to(device)
    logger.info("New unsupervised BERT model trained successfully.")
    return trained_model, trained_tokenizer


def load_processed_essays():
    """Load the processed-essay cache from ``processed_essays_file``.

    Returns:
        The stored dictionary, or an empty dict when the file is missing,
        empty (including whitespace-only), unreadable, not valid JSON, or valid
        JSON that is not an object. A file with invalid or non-object content
        is renamed to ``<path>.corrupted_<timestamp>``, so a later save does
        not overwrite it.
    """
    try:
        with open(processed_essays_file, 'r', encoding='utf-8') as f:
            content = f.read()
    except FileNotFoundError:
        logger.info(f"Processed essays file '{processed_essays_file}' not found. Starting fresh.")
        return {}
    except Exception as e:
        logger.error(f"Error loading processed essays from {processed_essays_file}: {e}. Returning empty dictionary.")
        return {}

    # A blank file is simply an empty cache, not a corrupted one.
    if not content.strip():
        logger.warning(f"Processed essays file '{processed_essays_file}' is empty. Returning empty dictionary.")
        return {}

    try:
        loaded = json.loads(content)
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON from {processed_essays_file}: {e}. Returning empty dictionary.")
        _backup_corrupted_processed_essays()
        return {}
    except Exception as e:
        logger.error(f"Error loading processed essays from {processed_essays_file}: {e}. Returning empty dictionary.")
        return {}

    # Every other return path is a dict; reject e.g. a top-level list or null.
    if not isinstance(loaded, dict):
        logger.error(f"Processed essays file '{processed_essays_file}' does not contain a JSON object (found {type(loaded).__name__}). Returning empty dictionary.")
        _backup_corrupted_processed_essays()
        return {}

    return loaded


def _backup_corrupted_processed_essays():
    """Move an unusable processed-essays file aside; log instead of raising."""
    backup_path = processed_essays_file + f".corrupted_{int(time.time())}"
    try:
        os.replace(processed_essays_file, backup_path)
        logger.info(f"Backed up corrupted processed essays file to {backup_path}")
    except OSError as rename_e:
        logger.error(f"Could not backup corrupted file: {rename_e}")


def save_processed_essays(processed_essays):
    """Write the processed-essay cache to ``processed_essays_file``.

    The JSON is first written to ``<path>.tmp`` and then moved over the real
    file with ``os.replace``, so a failed write leaves the previous file in
    place. Errors are logged, never raised; a leftover temporary file is removed.
    """
    try:
        staging_path = processed_essays_file + ".tmp"
        with open(staging_path, 'w', encoding='utf-8') as sink:
            json.dump(processed_essays, sink, ensure_ascii=False, indent=2)

        # Swap the finished file into place only after a successful write.
        os.replace(staging_path, processed_essays_file)
        logger.debug(f"Saved {len(processed_essays)} processed essays to {processed_essays_file}")
    except Exception as e:
        logger.error(f"Error saving processed essays to {processed_essays_file}: {e}")
        # Nothing to clean up if the staging file is not there.
        if not os.path.exists(staging_path):
            return
        try:
            os.remove(staging_path)
        except OSError as remove_error:
            logger.error(f"Error removing temporary save file {staging_path}: {remove_error}")


# 主要评估流程 - 修改为同步版本
def evaluate_essay_with_agents(essay_data: Dict[str, Any], samples_data: List[Dict[str, Any]] = None, bert_model=None, tokenizer=None):
    """Evaluate the topic relevance of one essay with a synchronous chain of LLM "agents".

    Pipeline (each step is one ``call_llm_api`` call, separated by a short
    random sleep of 0.2-0.5 s):

    * Agent 0 - essay type classification; a response containing "API_ERROR"
      falls back to the type "记叙文" instead of aborting.
    * Agents 1-5 - prompt interpretation, main idea summary, content/task
      alignment, material relevance and the final adjudication report. A
      response containing "API_ERROR" raises and aborts the whole evaluation.
    * Agent 6 - extracts a single label from the adjudication report; if the
      call fails, is empty or is not in ``CLASSIFICATION_LABELS``, the label is
      taken from ``extract_classification`` applied to the report instead.
    * Optional BERT cross-check - when ``bert_model`` and ``tokenizer`` are both
      truthy, ``predict_with_bert`` is called; if its label differs from the
      LLM label, an arbiter LLM call decides, and its answer is only accepted
      when it is one of ``CLASSIFICATION_LABELS``.

    Args:
        essay_data: Essay record; the keys read directly here are "id",
            "grade", "requirement", "title" and "content".
        samples_data: Few-shot samples forwarded to the prompt builders and to
            ``format_few_shot_examples``.
        bert_model: Optional model handed to ``predict_with_bert``.
        tokenizer: Optional tokenizer handed to ``predict_with_bert``.

    Returns:
        A dict with the keys "id", "classification",
        "llm_adjudicator_classification", "bert_classification", "essay_type",
        "analysis_reports" and "error". Any exception raised inside the
        pipeline is caught: "classification" is then "一般",
        "llm_adjudicator_classification" is "错误" and "error" holds the
        exception type and message. A final label outside
        ``CLASSIFICATION_LABELS`` is also forced to "一般".
    """
    essay_id = essay_data.get("id", "unknown")
    start_time = time.time()
    logger.info(f"[{essay_id}] 开始评估 (包含 Few-Shot 示例)...")

    analysis_reports = {}
    llm_agent_classification = "一般"  # Preliminary classification from the agent chain
    final_classification = "一般"  # Classification after the final ruling
    bert_classification = None
    error_occurred = False
    error_message = ""
    final_review_report = ""  # Store the final review report text
    essay_type = None  # Essay type, filled in by Agent 0

    try:
        # --- Essay type recognition ---
        logger.info(f"[{essay_id}] Agent 0: 识别作文类型...")
        essay_type_prompt = generate_essay_type_classifier_prompt(essay_data, few_shot_examples=samples_data) # pass the few-shot samples
        essay_type_response = call_llm_api(essay_type_prompt)
        if isinstance(essay_type_response, str) and "API_ERROR" in essay_type_response:
            logger.warning(f"API Error during Essay Type Classification: {essay_type_response}")
            essay_type = "记叙文"  # Default type when the API call fails
        else:
            essay_type = extract_essay_type(essay_type_response)
            analysis_reports["作文类型识别报告"] = essay_type_response
        logger.info(f"[{essay_id}] Agent 0: 完成. 识别作文类型为: '{essay_type}'.")
        time.sleep(random.uniform(0.2, 0.5))

        # --- Agent Chain Execution ---
        # 1. Interpretation of the essay requirement (Prompt Precision Analyst)
        logger.info(f"[{essay_id}] Agent 1: 精准解读作文要求...")
        prompt_interpretation_prompt = generate_prompt_interpretation_prompt(essay_data, essay_type, few_shot_examples=samples_data) # pass the few-shot samples
        prompt_interpretation = call_llm_api(prompt_interpretation_prompt)
        if isinstance(prompt_interpretation, str) and "API_ERROR" in prompt_interpretation: raise Exception(f"API Error during Prompt Interpretation: {prompt_interpretation}")
        analysis_reports["作文要求精准解读报告"] = prompt_interpretation
        logger.info(f"[{essay_id}] Agent 1: 完成.")
        time.sleep(random.uniform(0.2, 0.5)) # Keep small delay

        # 2. Main idea and core content extraction (Essay Essence Extractor)
        logger.info(f"[{essay_id}] Agent 2: 提炼文章主旨与核心内容...")
        main_idea_prompt = generate_main_idea_summarizer_prompt(essay_data, essay_type, few_shot_examples=samples_data) # pass the few-shot samples
        main_idea_summary = call_llm_api(main_idea_prompt)
        if isinstance(main_idea_summary, str) and "API_ERROR" in main_idea_summary: raise Exception(f"API Error during Essay Essence Extraction: {main_idea_summary}")
        analysis_reports["作文主旨与核心内容分析报告"] = main_idea_summary
        logger.info(f"[{essay_id}] Agent 2: 完成.")
        time.sleep(random.uniform(0.2, 0.5))

        # 3. Content/task conformity check (Relevance Alignment Verifier)
        logger.info(f"[{essay_id}] Agent 3: 校验内容与任务符合性...")
        content_task_prompt = generate_content_task_analyzer_prompt(
            essay_data, prompt_interpretation, main_idea_summary, essay_type, few_shot_examples=samples_data # pass the few-shot samples
        )
        content_task_analysis = call_llm_api(content_task_prompt)
        if isinstance(content_task_analysis, str) and "API_ERROR" in content_task_analysis: raise Exception(f"API Error during Content-Task Alignment Verification: {content_task_analysis}")
        analysis_reports["作文内容与任务符合性校验报告"] = content_task_analysis
        logger.info(f"[{essay_id}] Agent 3: 完成.")
        time.sleep(random.uniform(0.2, 0.5))

        # 4. Material/theme fit audit (Material-Theme Alignment Auditor)
        logger.info(f"[{essay_id}] Agent 4: 审核素材主题契合度...")
        material_relevance_prompt = generate_material_relevance_checker_prompt(
            essay_data, prompt_interpretation, main_idea_summary, essay_type, few_shot_examples=samples_data # pass the few-shot samples
        )
        material_relevance_analysis = call_llm_api(material_relevance_prompt)
        if isinstance(material_relevance_analysis, str) and "API_ERROR" in material_relevance_analysis: raise Exception(f"API Error during Material-Theme Alignment Audit: {material_relevance_analysis}")
        analysis_reports["作文素材主题契合度审核报告"] = material_relevance_analysis
        logger.info(f"[{essay_id}] Agent 4: 完成.")
        time.sleep(random.uniform(0.2, 0.5))

        # 5. Final relevance ruling (Final Relevance Adjudicator) - This agent now gives the primary classification
        logger.info(f"[{essay_id}] Agent 5: 进行最终切题度裁决...")
        comprehensive_prompt = generate_comprehensive_reviewer_prompt(essay_data, analysis_reports, essay_type, few_shot_examples=samples_data) # pass the few-shot samples
        comprehensive_review = call_llm_api(comprehensive_prompt)
        if isinstance(comprehensive_review, str) and "API_ERROR" in comprehensive_review: raise Exception(f"API Error during Final Relevance Adjudication: {comprehensive_review}")
        analysis_reports["切题度最终裁决报告"] = comprehensive_review
        final_review_report = comprehensive_review # Store the report for potential arbitration

        # 6. Extract the final classification (directly from the Adjudicator's report)
        logger.info(f"[{essay_id}] Agent 6: 提取最终分类...")
        # Note: this agent's prompt does not receive the few-shot examples
        final_class_prompt = generate_final_classification_prompt(essay_data, comprehensive_review)
        llm_agent_classification_raw = call_llm_api(final_class_prompt, retry_on_empty=False) # Don't retry if empty, handle below

        # Process Agent 6 result or fallback to regex
        if isinstance(llm_agent_classification_raw, str) and ("API_ERROR" in llm_agent_classification_raw or not llm_agent_classification_raw.strip()):
             logger.warning(f"[{essay_id}] Failed to extract final classification via Agent 6 (API Error or Empty: '{llm_agent_classification_raw[:50]}...'). Falling back to regex extraction on Adjudicator report.")
             llm_agent_classification = extract_classification(comprehensive_review) # Regex fallback
             if llm_agent_classification == "一般": # Check if regex also failed
                  logger.error(f"[{essay_id}] Fallback regex extraction also failed to find a specific class. Defaulting Adjudicator classification to '一般'. Report Snippet: {comprehensive_review[:200]}...")
        else:
             # NOTE(review): a non-str response reaches this branch and .strip() would
             # raise; the outer except then records the essay as an error.
             potential_class = llm_agent_classification_raw.strip()
             if potential_class in CLASSIFICATION_LABELS:
                 llm_agent_classification = potential_class
                 logger.debug(f"[{essay_id}] Agent 6 extracted classification: '{llm_agent_classification}'")
             else:
                 logger.warning(f"[{essay_id}] Agent 6 extracted an invalid classification '{potential_class}'. Falling back to regex extraction on Adjudicator report.")
                 llm_agent_classification = extract_classification(comprehensive_review) # Regex fallback

        final_classification = llm_agent_classification # Initially set final class from LLM chain
        logger.info(f"[{essay_id}] Agent 5/6: 完成. LLM Adjudicator 分类: '{llm_agent_classification}'.")
        time.sleep(random.uniform(0.2, 0.5))


        # --- BERT Auxiliary Evaluation & Arbitration ---
        if bert_model and tokenizer:
            logger.info(f"[{essay_id}] BERT: 进行辅助预测...")
            bert_classification = predict_with_bert(bert_model, tokenizer, essay_data)

            if bert_classification:
                logger.info(f"[{essay_id}] BERT: 预测分类为 '{bert_classification}'.")

                # Validate LLM classification before comparison
                if llm_agent_classification not in CLASSIFICATION_LABELS:
                     logger.warning(f"[{essay_id}] LLM Adjudicator classification '{llm_agent_classification}' is invalid. Cannot compare with BERT. Skipping arbitration. Using LLM result '{llm_agent_classification}'.")
                     # Keep the potentially invalid LLM classification for now, handle at the end
                     final_classification = llm_agent_classification

                elif llm_agent_classification != bert_classification:
                    logger.warning(f"[{essay_id}] LLM Adjudicator 分类 ('{llm_agent_classification}') 与 BERT 分类 ('{bert_classification}') 不一致，启动最终冲突裁决...")

                    # Format the few-shot examples for the arbiter prompt
                    few_shot_prompt_arbiter = format_few_shot_examples(samples_data)

                    # --- Final Arbiter Agent ---
                    # The template below is filled with str.format, so it must not
                    # contain literal braces other than the four placeholders.
                    arbiter_system_prompt = """
你的角色: 教育部作文评审首席专家(30年资历)和国家级特级教师(主管作文教学)，职责是解决自动化评级系统（LLM Agent链与BERT模型）之间的切题度评级分歧。

你的任务:
1.  **绝对独立地、批判性地**审阅作文原文和题目要求。**这是你判断的唯一基石。**
2.  **知悉**下方提供的两个存在分歧的初步评级结果及其简要依据。**仅将它们视为参考信息，绝不能被其左右或束缚。**
3.  **严格对照**下方官方切题度等级标准，对作文的切题度做出**唯一、客观、精准**的最终裁决。
4.  **聚焦核心问题**: 作文是否完成了题目要求？中心思想是否明确、切题？材料是否有效支撑中心？是否存在硬伤（如严重偏题、未满足限制条件）？**优先判断是否属于“不合格”,"合格"或“其他三项”范畴。**

官方切题度等级标准 (**必须严格遵守，作为唯一评判依据**):
优秀(5分): 90%及以上完成题目核心任务要求，作文**高度契合**写作要求的所有方面。立意精准深刻，中心思想**极其**明确、突出。选材**完全**服务于中心，恰到好处，与主题高度统一、相得益彰，无任何偏离性内容，素材均与主题形成强关联。
较好(4分): 完成主要写作任务80%到90%之间，作文**充分满足**写作要求。立意准确，中心思想**很**明确、清晰。选材**紧密**围绕中心，较为恰当，与主题结合良好，主要素材有效支撑主题。
一般(3分): 基本完成任务70%到80%之间 ，作文**基本符合**写作要求。立意**大致准确**，中心思想**尚属**清晰，但可能不够突出或集中。选材**基本**能服务于中心，但可能存在部分关联性不强或不够典型的材料。整体表现**可以接受**。
合格(2分): 任务完成度60%到70%之间，作文**最低限度**满足了写作的基本要求，**未完全脱离**主题范围。但立意**不够精准**，中心思想**不明确**、**不突出**，需要费力寻找。选材与主题的**关联度低**，**不能有效支撑**中心。**存在明显偏题、跑题或对要求理解不到位的情况，中心思想需反复推敲才能关联题目**。
不合格(1分): 核心任务未完成或少于60%，作文**严重偏离**写作要求。立意**错误或极其模糊**。中心思想**完全不清晰**，无法把握，中心思想与题目要求相悖。选材与主题**毫不相干**或**完全相悖**。**完全未理解题目要求，或存在根本性跑题**。

核心裁决原则:
1.  **任务完成度优先**: 文章是否完成了题目的核心任务？是否遵守了所有明确的限制条件（如文体、字数、特定内容要求等）？**任何核心任务未完成或限制条件未遵守，原则上不能评为“一般”及以上。**
2.  **中心思想是关键**: 中心思想是否明确？是否紧扣题目要求？**中心思想模糊、偏离是判定为“合格”或“不合格”的关键指标。**
3.  **内容与材料的相关性**: 选择的内容和材料是否**直接且有效**地支撑中心思想？**材料空泛、堆砌、与主题关联弱是拉低评分的重要因素。**
4.  **严格把控界限**: **特别注意“合格”与“不合格”、“一般”与“合格”的界限。** 对于可能处于边界的作文，**必须从严审视**其是否达到了更高一级标准的所有基本要求。**绝不允许因为个别语句尚可或结构完整等非切题因素而将“合格”作文提升至“一般”，或将“不合格”作文评为“合格”。**
5.  **综合判断，拒绝“平均主义”**: 依据作文整体表现与标准的符合度进行判断，而不是简单地对专家报告进行平均。要敢于推翻明显不合理的初步意见。
{few_shot_prompt}
输入参考:
LLM Agent链裁决: {llm_class} (依据见: {llm_report_snippet})
BERT模型预测: {bert_class} (无详细理由)

**最终输出要求:**
**直接输出最终裁决的等级词语，且仅输出该词语。**
**必须是以下五个词语之一： 优秀 | 较好 | 一般 | 合格 | 不合格**
**禁止包含任何其他文字、解释、说明、标点符号、或空格。**
""".format(
                        few_shot_prompt=few_shot_prompt_arbiter, # insert the few-shot examples
                        llm_class=llm_agent_classification,
                        bert_class=bert_classification,
                        llm_report_snippet=f"LLM报告片段: {final_review_report[:600]}..." if final_review_report else "LLM报告不可用"
                    )

                    arbiter_user_prompt = f"""
请对以下作文的切题度进行最终冲突裁决：

年级： {essay_data.get('grade', '未知')}年级
作文要求原文： {essay_data.get('requirement', '无')}
作文标题： {essay_data.get('title', '无')}
作文原文：
{essay_data.get('content', '无')}

冲突评级:
LLM Agent链裁决: {llm_agent_classification}
BERT 模型预测: {bert_classification}

请严格按要求，给出最终裁决（只回答一个词：优秀/较好/一般/合格/不合格）。
"""
                    arbiter_messages = [
                        Message(role="system", content=arbiter_system_prompt),
                        Message(role="user", content=arbiter_user_prompt)
                    ]

                    logger.info(f"[{essay_id}] 调用最终冲突裁决 Agent...")
                    arbiter_response_raw = call_llm_api(arbiter_messages, retry_on_empty=False)

                    # Process arbiter response
                    if isinstance(arbiter_response_raw, str) and ("API_ERROR" in arbiter_response_raw or not arbiter_response_raw.strip()):
                        logger.error(f"[{essay_id}] 最终冲突裁决API调用失败或返回空. 保留 LLM Adjudicator 结果 '{llm_agent_classification}'. Response: {arbiter_response_raw[:50]}...")
                        final_classification = llm_agent_classification # Stick with original LLM result on arbiter failure
                    else:
                        potential_arbiter_class = arbiter_response_raw.strip()
                        if potential_arbiter_class in CLASSIFICATION_LABELS:
                             logger.info(f"[{essay_id}] 最终冲突裁决分类为: '{potential_arbiter_class}'. 使用此结果。")
                             final_classification = potential_arbiter_class # Update final classification
                        else:
                             logger.warning(f"[{essay_id}] 最终冲突裁决未能提取有效分类 ('{potential_arbiter_class}'). 保留 LLM Adjudicator 结果 '{llm_agent_classification}'.")
                             final_classification = llm_agent_classification # Stick with original LLM result if arbiter output invalid

                else:
                    # LLM and BERT agree
                    logger.info(f"[{essay_id}] LLM Adjudicator 与 BERT 分类一致 ('{llm_agent_classification}')，无需冲突裁决. 使用此结果。")
                    final_classification = llm_agent_classification # Confirmed classification

            else:
                 # BERT prediction failed
                 logger.info(f"[{essay_id}] BERT: 未能生成预测结果. 使用 LLM Adjudicator 结果 '{llm_agent_classification}'.")
                 final_classification = llm_agent_classification # Stick with LLM result
        else:
            # BERT model/tokenizer not available
            logger.info(f"[{essay_id}] BERT: 模型或Tokenizer不可用，跳过辅助评估. 使用 LLM Adjudicator 结果 '{llm_agent_classification}'.")
            final_classification = llm_agent_classification # Stick with LLM result

    except Exception as e:
        # Any failure in the chain (including the API errors raised above) ends here;
        # the essay still gets a result, marked with the error.
        logger.error(f"[{essay_id}] 评估过程中发生严重错误: {str(e)}", exc_info=True)
        error_occurred = True
        error_message = f"{type(e).__name__}: {str(e)}"
        final_classification = "一般" # Default on error
        llm_agent_classification = "错误" # Mark LLM result as error

    # --- Result Aggregation & Final Validation ---
    final_classification_safe = final_classification # Start with the determined classification

    if error_occurred:
        final_classification_safe = "一般" # Override to default if a major error happened
    elif final_classification not in CLASSIFICATION_LABELS:
         logger.warning(f"[{essay_id}] 最终确定的分类 '{final_classification}' 不是有效标签. 将强制设为 '一般'.")
         final_classification_safe = "一般" # Force to default if invalid at the very end

    result = {
        "id": essay_id,
        "classification": final_classification_safe, # Use the safe, validated classification
        "llm_adjudicator_classification": llm_agent_classification if not error_occurred else "错误",
        "bert_classification": bert_classification, # Can be None if BERT failed/skipped
        "essay_type": essay_type, # Essay type identified by Agent 0
        "analysis_reports": analysis_reports,
        "error": error_message if error_occurred else None
    }

    end_time = time.time()
    duration = end_time - start_time
    logger.info(f"[{essay_id}] 评估完成. 最终分类: '{result['classification']}'. 耗时: {duration:.2f} 秒.")
    return result


def _sort_results_by_input_order(essays_data, results_summary, detailed_results):
    """Return both result lists filtered to dicts and ordered like ``essays_data``.

    The sort key is the position at which each stringified ID FIRST appears in
    ``essays_data`` (matching the keep-first duplicate policy of
    ``evaluate_essays``). Results whose ID is unknown sort last. When the
    input carries no usable ID at all, both lists are returned untouched.
    """
    order_map = {}
    for position, entry in enumerate(essays_data):
        if isinstance(entry, dict) and entry.get("id") is not None:
            order_map.setdefault(str(entry.get("id")), position)
    if not order_map:
        return results_summary, detailed_results

    def input_position(item):
        return order_map.get(str(item.get("id")), float('inf'))

    # sorted() is stable, so entries sharing a key keep their relative order.
    ordered_summary = sorted((r for r in results_summary if isinstance(r, dict)), key=input_position)
    ordered_detailed = sorted((r for r in detailed_results if isinstance(r, dict)), key=input_position)
    return ordered_summary, ordered_detailed


def evaluate_essays(essays_data: List[Dict[str, Any]], samples_data: Optional[List[Dict[str, Any]]] = None):
    """Evaluate essays one after another, reusing cached results where available.

    Entries that are not dicts, have no "id", or repeat an earlier ID are
    skipped with a warning. Essays whose stringified ID is already in the cache
    returned by ``load_processed_essays`` are answered from the cache; the rest
    are passed to ``evaluate_essay_with_agents`` and the cache is saved after
    each one.

    Args:
        essays_data: Essay records to evaluate.
        samples_data: Few-shot samples forwarded to ``evaluate_essay_with_agents``.

    Returns:
        A tuple ``(results_summary, detailed_results)``, both ordered like the
        input. ``results_summary`` holds ``{"id", "classification"}`` dicts;
        ``detailed_results`` holds the original essay dict plus its
        "classification".
    """
    results_summary = []  # Submission-style results: {id: ..., classification: ...}
    detailed_results = []  # Original essay data merged with the final classification

    # Load the record of essays that were already processed
    processed_essays_cache = load_processed_essays()
    essays_to_process = []
    processed_ids = set()  # IDs already answered from cache or queued, to reject duplicates

    # Split the input into cache hits and essays that still need evaluation
    for essay in essays_data:
        # Ensure essay is a dictionary and has an ID
        if not isinstance(essay, dict):
            logger.warning(f"Skipping non-dictionary item in input data: {type(essay)}")
            continue
        essay_id_val = essay.get("id")
        if essay_id_val is None:
            logger.warning(f"Skipping essay with missing ID: {essay}")
            continue
        essay_id_str = str(essay_id_val)  # Ensure ID is string

        if essay_id_str in processed_ids:
            logger.warning(f"Skipping essay with duplicate ID: {essay_id_str}")
            continue
        processed_ids.add(essay_id_str)

        if essay_id_str not in processed_essays_cache:
            essays_to_process.append(essay)
            continue

        logger.info(f"作文 ID: {essay_id_str} 在缓存中找到，加载缓存结果。")
        cached_result = processed_essays_cache[essay_id_str]
        summary_classification = cached_result.get("classification", "一般")
        # Validate classification from cache
        if summary_classification not in CLASSIFICATION_LABELS:
            logger.warning(f"ID {essay_id_str}: 缓存中的分类 '{summary_classification}' 无效，将使用 '一般'。")
            summary_classification = "一般"
            cached_result["classification"] = "一般"  # Fix in loaded cache data for consistency

        # Important: keep the original (string) ID here. Reformatting happens later in main.
        results_summary.append({
            "id": cached_result.get("id", essay_id_str),
            "classification": summary_classification
        })

        # The essay in hand IS the original record for this ID, so use it directly
        # instead of rescanning essays_data (which was O(n^2) and crashed on any
        # non-dict entry that preceded it).
        detailed_results.append({**essay, "classification": summary_classification})

    processed_count_cache = len(processed_ids) - len(essays_to_process)
    total_valid_essays = len(processed_ids)  # Total unique IDs found
    logger.info(f"共 {len(essays_data)} 条输入记录，识别出 {total_valid_essays} 个有效唯一ID。")
    logger.info(f"{processed_count_cache} 篇已在缓存中，需要处理 {len(essays_to_process)} 篇新作文。")

    if not essays_to_process:
        logger.info("所有需要的作文均已处理，无需执行新的评估。")
        return _sort_results_by_input_order(essays_data, results_summary, detailed_results)

    # Load or train the BERT model on all valid essays (cached and pending)
    all_valid_data_for_bert = [e for e in essays_data if isinstance(e, dict) and str(e.get("id")) in processed_ids]
    bert_model, tokenizer = load_or_train_bert_model_unsupervised(all_valid_data_for_bert)

    # Process the pending essays sequentially behind a progress bar
    progress_bar = tqdm(essays_to_process, desc="评估作文进度", position=0)

    for i, essay_data in enumerate(progress_bar):
        current_index = processed_count_cache + i  # Overall index including cached items
        logger.info(f"处理作文 {current_index + 1}/{total_valid_essays} (新处理第 {i+1} 篇), ID: {essay_data.get('id', 'unknown')}")
        try:
            result = evaluate_essay_with_agents(essay_data, samples_data=samples_data, bert_model=bert_model, tokenizer=tokenizer)

            if result and isinstance(result, dict) and "id" in result and result.get("id") != "unknown":
                essay_id_str = str(result["id"])
                # NOTE(review): results carrying an "error" are cached as well, so a
                # failed essay is answered from the cache on the next run rather than
                # retried - confirm this is intended.
                processed_essays_cache[essay_id_str] = result
                # Save cache immediately after processing each essay
                save_processed_essays(processed_essays_cache)

                # Already validated inside evaluate_essay_with_agents
                summary_classification = result.get("classification", "一般")

                results_summary.append({
                    "id": result.get("id"),  # Keep original ID for now
                    "classification": summary_classification
                })
                # Merge the essay being processed with its final classification
                detailed_results.append({**essay_data, "classification": summary_classification})

            elif result:
                logger.warning(f"评估作文 ID {essay_data.get('id', 'unknown')} 返回结果缺少有效ID或不是字典: {result}")
            else:
                logger.error(f"评估作文 ID {essay_data.get('id', 'unknown')} 任务返回了 None 或无效结果.")
                # Placeholder keeps count/order aligned with the input
                results_summary.append({"id": essay_data.get('id'), "classification": "错误"})
                detailed_results.append({**essay_data, "classification": "错误"})

        except Exception as e:
            logger.error(f"处理作文 ID {essay_data.get('id', 'unknown')} 时发生未捕获错误: {e}", exc_info=True)
            # Add error placeholder
            results_summary.append({"id": essay_data.get('id'), "classification": "错误"})
            detailed_results.append({**essay_data, "classification": "错误"})

        # --- Memory cleanup: attempted after every 20 newly processed essays ---
        if (i + 1) % 20 == 0:
            logger.info(f"处理完 {i + 1} 篇新作文（总计 {current_index + 1} 篇），尝试清理内存...")
            try:
                # Trigger garbage collection manually
                collected_count = gc.collect()
                logger.debug(f"gc.collect() called. Collected {collected_count} objects.")

                # Release PyTorch's cached CUDA memory when a GPU is in use
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.debug("torch.cuda.empty_cache() called.")
                logger.info("内存清理尝试完成。")
            except Exception as e:
                logger.error(f"内存清理过程中发生错误: {e}", exc_info=True)

    # Cached and new results were appended in different phases; restore input order
    return _sort_results_by_input_order(essays_data, results_summary, detailed_results)


def load_data(file_path: str) -> List[Dict[str, Any]]:
    """Read a JSON file and return its top-level list.

    An empty list is returned (after logging the reason) when the file does
    not exist, cannot be parsed as JSON, does not hold a list at the top
    level, or any other error occurs while loading.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)

        # Only a top-level JSON array is accepted.
        if isinstance(records, list):
            logger.info(f"成功加载数据: {len(records)} 条记录从 {file_path}")
            return records

        logger.error(f"错误: 数据文件 {file_path} 的顶层不是一个列表。")
        return []
    except FileNotFoundError:
        logger.error(f"错误: 数据文件未找到 {file_path}")
    except json.JSONDecodeError as decode_error:
        logger.error(f"错误: 无法解析JSON文件 {file_path}. Error: {decode_error}")
    except Exception as unexpected_error:
        logger.error(f"加载数据时发生未知错误: {unexpected_error}")
    # Every handled failure ends with the same empty result.
    return []

# 新增：加载样本文件
def load_samples(file_path: str) -> List[Dict[str, Any]]:
    """Load the few-shot sample file (JSON) through ``load_data``.

    Always returns a list; it is empty when nothing could be loaded, in which
    case a warning is logged.
    """
    logger.info(f"尝试加载 Few-Shot 样本文件: {file_path}")
    loaded = load_data(file_path)  # same loader as the essay data
    if not loaded:
        logger.warning(f"未能加载 Few-Shot 样本文件或文件为空: {file_path}。评估将不使用 Few-Shot 示例。")
        return []
    logger.info(f"成功加载 {len(loaded)} 条 Few-Shot 样本。")
    return loaded


def save_results(results: List[Dict[str, Any]], output_file: str) -> None:
    """Write ``results`` as JSON to ``output_file`` via a temporary file.

    The data is dumped to ``output_file + ".tmp"`` and then moved into place
    with ``os.replace``. The parent directory is created when missing. Errors
    are logged rather than raised, and a leftover temporary file is removed on
    a best-effort basis.

    Args:
        results: JSON-serialisable list of result records.
        output_file: Destination path of the JSON file.
    """
    # Bound before the try so the cleanup in the except block can always refer
    # to it, even when directory creation is what failed.
    temp_file_path = output_file + ".tmp"
    try:
        # Ensure directory exists
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(temp_file_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        # The temp file sits next to the target, so this is a same-directory rename.
        os.replace(temp_file_path, output_file)
        logger.info(f"结果成功保存到: {output_file}")
    except Exception as e:
        logger.error(f"错误: 无法写入结果文件 {output_file}. Error: {e}")
        # Clean up temp file if it still exists after error
        if os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as remove_error:
                logger.error(f"Error removing temporary save file {temp_file_path}: {remove_error}")

def add_ids_to_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Give every record a unique string ``"id"`` and drop repeated IDs.

    Non-dict entries are skipped. A record without an ID receives
    ``auto_id_<index>`` (its position in ``data``); an existing ID is converted
    to ``str``. Records are modified in place. Of several records sharing an
    ID, only the first is kept. Falsy input is returned unchanged.
    """
    if not data:
        return data

    kept_records = []
    first_seen_at = {}  # id -> index of the record that claimed it
    auto_assigned = 0
    repeated_ids = []

    for position, record in enumerate(data):
        if not isinstance(record, dict):
            logger.warning(f"Skipping non-dictionary item at index {position}: {type(record)}")
            continue

        raw_id = record.get("id")
        if raw_id is not None:
            # Normalise to a string so IDs compare consistently.
            record_id = str(raw_id)
        else:
            # No ID supplied: derive one from the record's position.
            record_id = f"auto_id_{position}"
            auto_assigned += 1
        record["id"] = record_id
        if raw_id is None:
            logger.debug(f"Assigned auto ID '{record_id}' to item at index {position}")

        if record_id in first_seen_at:
            logger.warning(f"发现重复 ID: '{record_id}' 在索引 {position}。将跳过此重复条目。Previous index: {first_seen_at[record_id]}")
            repeated_ids.append(record_id)
            continue

        first_seen_at[record_id] = position
        kept_records.append(record)

    if auto_assigned > 0:
        logger.info(f"{auto_assigned} 条记录缺少 'id' 字段，已自动添加 (格式: auto_id_INDEX)。")
    # Every dict record was either kept or counted as a repeat.
    duplicates_removed = len(repeated_ids)
    if duplicates_removed > 0:
        logger.warning(f"由于重复ID，原始数据中的 {duplicates_removed} 条有效记录已被移除。重复的ID包括: {list(set(repeated_ids))}")
    if not kept_records:
        logger.warning("处理ID后数据列表为空。")

    return kept_records


def main():
    """Run the full pipeline: load essays, evaluate them, write the output file.

    Steps: load ``input_file``, normalise/deduplicate IDs, load the few-shot
    samples, evaluate all essays, renumber the ordered results with integer
    IDs starting at 0 and save them to ``output_file``. Returns early (after
    logging) when there is no usable input.
    """
    started_at = time.time()

    logger.info(f"加载测试数据从: {input_file}")
    test_data = load_data(input_file)
    if not test_data:
        logger.error("无法加载测试数据或数据为空，程序终止。")
        return

    # Normalise IDs and drop duplicates before anything else touches the data.
    test_data = add_ids_to_data(test_data)
    if not test_data:
        logger.error("处理测试数据ID后数据为空（可能所有条目都无效或重复），程序终止。")
        return

    # May be an empty list when the sample file cannot be loaded.
    samples_data = load_samples(samples_file)

    logger.info("开始顺序评估所有作文...")
    # The summary comes back ordered like the input; the detailed results
    # (original IDs) are currently not written anywhere.
    results_summary, _detailed_results = evaluate_essays(test_data, samples_data=samples_data)

    # Renumber the ordered summary with sequential integer IDs for the output file.
    final_formatted_results = []
    if results_summary:
        logger.info(f"重新格式化最终结果的 ID 为从 0 开始的整数，用于输出文件 {output_file}...")
        for position, entry in enumerate(results_summary):
            if not (isinstance(entry, dict) and "classification" in entry):
                # Unexpected item in the summary list: report it and move on.
                logger.warning(f"在最终结果列表中发现无效条目 (索引 {position})，跳过格式化: {entry}")
                continue
            final_formatted_results.append({
                "id": position,
                "classification": entry.get("classification", "一般"),
            })

    # Save the renumbered results, or explain why nothing is written.
    if final_formatted_results:
        logger.info(f"保存最终格式化分类结果 ({len(final_formatted_results)} 条) 到: {output_file}")
        save_results(final_formatted_results, output_file)
    elif results_summary:
        logger.warning("格式化最终结果列表失败或为空，但原始摘要结果存在。将不保存主输出文件。")
    else:
        logger.warning("没有生成任何最终分类结果，不保存文件。")

    total_duration = time.time() - started_at
    processed_count = len(final_formatted_results)
    logger.info(f"所有评估任务完成。共处理 {processed_count} 篇有效作文。总耗时: {total_duration:.2f} 秒.")
    if processed_count > 0:
        avg_time = total_duration / processed_count
        logger.info(f"平均每篇作文评估耗时: {avg_time:.2f} 秒.")

# Script entry point: run main() and log any exception that escapes it at
# CRITICAL level (with traceback) instead of letting it propagate.
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.critical(f"主程序运行时发生严重错误: {e}", exc_info=True)