# 中文新闻文本标题分类

## 环境准备、路径与超参数总配置 (Global Configuration: Environment, Paths & Hyperparameters)

In [None]:
# ==============================================================================
# 0. 导入库 (Import Libraries)
# ==============================================================================
import os
import csv # 用于读写CSV文件
import json
import time
import torch
import datasets
import logging
import evaluate
import subprocess
import transformers
import pandas as pd
import numpy as np
from PIL import Image 
from tqdm.auto import tqdm
from datasets import Dataset 
from torch.optim import AdamW
from datetime import datetime
from functools import partial 
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
    DataCollatorWithPadding
)

# ==============================================================================
# 1. 基础环境与日志配置 (Basic Environment & Logging Setup)
# ==============================================================================

# --- 1.1 网络代理设置 (Network Proxy Setup) ---
# autodl代理加速配置
# try:
#     result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True, check=True)
#     output = result.stdout
#     for line in output.splitlines():
#         if '=' in line:
#             var, value = line.split('=', 1)
#             os.environ[var] = value
#     print("代理环境变量设置成功。") # 使用 print 因为 log 可能还未配置
# except subprocess.CalledProcessError as e:
#     print(f"设置代理环境变量失败: {e}")
# except FileNotFoundError:
#     print("无法执行 bash 命令，可能不在 Linux 环境或未安装 bash。")

# --- 1.2 日志配置 (Logging Configuration) ---
LOG_DIR = "./files/logs"
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)

# TensorBoard Writer - 保留时间戳以区分不同的运行
TENSORBOARD_RUN_NAME = datetime.now().strftime("%Y%m%d-%H%M%S") 
TENSORBOARD_LOG_DIR = os.path.join(LOG_DIR, "tensorboard_runs", TENSORBOARD_RUN_NAME)
writer = SummaryWriter(TENSORBOARD_LOG_DIR)

# 文件日志 
LOG_FILE_NAME = f"training_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
LOG_FILE_PATH = os.path.join(LOG_DIR, LOG_FILE_NAME)

# 清除已存在的handlers，避免重复日志
if logging.root.handlers:
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()

if not logging.getLogger().hasHandlers(): # 只在没有 handler 的时候配置
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # 添加了 %(name)s
        handlers=[
            logging.FileHandler(LOG_FILE_PATH, mode='w', encoding='utf-8'), # mode='w' 每次覆盖
            logging.StreamHandler() # 同时输出到控制台
        ]
    )
log = logging.getLogger(__name__) # 获取一个logger实例

log.info("环境准备与日志配置完成。")
log.info(f"文本日志将保存到: {LOG_FILE_PATH}")
log.info(f"TensorBoard 日志将保存到: {TENSORBOARD_LOG_DIR}")


# --- 1.3 版本与设备信息 (Version & Device Information) ---
log.info(f"PyTorch 版本: {torch.__version__}")
log.info(f"Transformers 版本: {transformers.__version__}")
log.info(f"Datasets 版本: {datasets.__version__}") # 需要 import datasets
log.info(f"Evaluate 版本: {evaluate.__version__}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info(f"将使用设备: {device}")
if torch.cuda.is_available():
    log.info(f"CUDA 可用. GPU数量: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        log.info(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    log.info(f"当前默认GPU: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    log.info("CUDA 不可用, 将使用 CPU.")

# ==============================================================================
# 2. 路径与文件名配置 (Paths & Filenames Configuration)
# ==============================================================================
# 确保主文件目录存在
if not os.path.exists("./files"):
    os.makedirs("./files")

# --- 2.1 原始数据路径 (Raw Data Paths) ---
RAW_DATA_DIR = "./files/raw_data"
if not os.path.exists(RAW_DATA_DIR): 
    os.makedirs(RAW_DATA_DIR)
    
TRAIN_FILE_NAME = "train.txt"
DEV_FILE_NAME = "dev.txt"
TEST_FILE_NAME = "test.txt"
TRAIN_ADD_FILE_NAME = "train_add.txt" # 用于数据增强的文件名

# 完整路径
TRAIN_FILE_PATH = os.path.join(RAW_DATA_DIR, TRAIN_FILE_NAME)
DEV_FILE_PATH = os.path.join(RAW_DATA_DIR, DEV_FILE_NAME)
TEST_FILE_PATH = os.path.join(RAW_DATA_DIR, TEST_FILE_NAME)
AUG_TRAIN_FILE_PATH = os.path.join(RAW_DATA_DIR, TRAIN_ADD_FILE_NAME) # 增强数据文件路径

# --- 2.2 预处理数据缓存路径 (Processed Data Cache Paths) ---
CACHE_DATA_DIR = "./files/processed_data_cache"
if not os.path.exists(CACHE_DATA_DIR): 
    os.makedirs(CACHE_DATA_DIR)
    
CACHE_TRAIN_PATH = os.path.join(CACHE_DATA_DIR, "train_dataset_cache")
CACHE_DEV_PATH = os.path.join(CACHE_DATA_DIR, "dev_dataset_cache")
CACHE_TEST_PATH = os.path.join(CACHE_DATA_DIR, "test_dataset_cache")
CACHE_AUG_TRAIN_PATH = os.path.join(CACHE_DATA_DIR, "aug_train_dataset_cache") # 缓存增强数据

# --- 2.3 模型保存路径 (Model Saving Paths) ---
SAVED_MODELS_DIR = "./files/saved_models"
if not os.path.exists(SAVED_MODELS_DIR): 
    os.makedirs(SAVED_MODELS_DIR)
    
BEST_MODEL_NAME = "best_roberta_model.pt" # 初次微调的最佳模型
BEST_MODEL_PATH = os.path.join(SAVED_MODELS_DIR, BEST_MODEL_NAME)

BEST_AUG_MODEL_NAME = "best_augmented_model.pt" # 二次微调后的最佳模型
BEST_AUG_MODEL_PATH = os.path.join(SAVED_MODELS_DIR, BEST_AUG_MODEL_NAME)

# --- 2.4 最终预测结果输出路径 (Final Prediction Results Output Paths) ---
FINAL_RESULTS_DIR = "./files/results"
if not os.path.exists(FINAL_RESULTS_DIR): 
    os.makedirs(FINAL_RESULTS_DIR)

FINAL_TOP1_JSON_FILE_NAME = "final_test_top1_predictions.json" 
FINAL_RESULT_TXT_FILE_NAME = "result.txt"  

FINAL_TOP1_JSON_PATH = os.path.join(FINAL_RESULTS_DIR, FINAL_TOP1_JSON_FILE_NAME)
FINAL_RESULT_TXT_PATH = os.path.join(FINAL_RESULTS_DIR, FINAL_RESULT_TXT_FILE_NAME)

# --- 2.5 初步分析结果输出路径 (Initial Analysis Results Output Paths - after 1st fine-tuning) ---
ANALYSIS_DIR = "./files/analysis_results" 
if not os.path.exists(ANALYSIS_DIR): 
    os.makedirs(ANALYSIS_DIR)
    
TRAIN_ADD_FILE_PATH = AUG_TRAIN_FILE_PATH

ALL_TEST_ANALYSIS_OUTPUT_DIR = os.path.join(ANALYSIS_DIR, "all_test_analysis")
if not os.path.exists(ALL_TEST_ANALYSIS_OUTPUT_DIR): 
    os.makedirs(ALL_TEST_ANALYSIS_OUTPUT_DIR)
    
TOP1_JSON_FILE_NAME = "test_top1_predictions.json" 
PROB_DIST_PLOT_FILE_NAME = "test_prob_distribution.png" 
PROB_DIST_CSV_FILE_NAME = "test_prob_distribution_data.csv" 

TOP1_JSON_PATH = os.path.join(ALL_TEST_ANALYSIS_OUTPUT_DIR, TOP1_JSON_FILE_NAME)
PROB_DIST_PLOT_PATH = os.path.join(ALL_TEST_ANALYSIS_OUTPUT_DIR, PROB_DIST_PLOT_FILE_NAME)
PROB_DIST_CSV_PATH = os.path.join(ALL_TEST_ANALYSIS_OUTPUT_DIR, PROB_DIST_CSV_FILE_NAME)

# --- 2.6 最终预测分析的图表/CSV输出路径 (Final Prediction Analysis Plot/CSV Output Paths) ---
FINAL_ANALYSIS_SUBDIR_NAME = "prediction_analysis_charts_csv" # 子目录名
FINAL_ANALYSIS_OUTPUT_PATH = os.path.join(FINAL_RESULTS_DIR, FINAL_ANALYSIS_SUBDIR_NAME)
if not os.path.exists(FINAL_ANALYSIS_OUTPUT_PATH): 
    os.makedirs(FINAL_ANALYSIS_OUTPUT_PATH)

FINAL_PROB_DIST_PLOT_FILE_NAME = "final_test_prob_distribution.png"  
FINAL_PROB_DIST_CSV_FILE_NAME = "final_test_prob_distribution_data.csv" 

FINAL_PROB_DIST_PLOT_PATH = os.path.join(FINAL_ANALYSIS_OUTPUT_PATH, FINAL_PROB_DIST_PLOT_FILE_NAME)
FINAL_PROB_DIST_CSV_PATH = os.path.join(FINAL_ANALYSIS_OUTPUT_PATH, FINAL_PROB_DIST_CSV_FILE_NAME)


# ==============================================================================
# 3. 模型与 Tokenizer 配置 (Model & Tokenizer Configuration)
# ==============================================================================
PRETRAINED_MODEL_NAME = 'hfl/chinese-roberta-wwm-ext-large'
# 优先使用本地缓存。如果 LOCAL_MODEL_PATH 为 None、空字符串或无效路径，
# AutoTokenizer/AutoModel.from_pretrained 会自动尝试从 PRETRAINED_MODEL_NAME 下载或使用Hub缓存。
LOCAL_MODEL_PATH = '/root/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext-large/snapshots/a25cc9e05974bd9687e528edd516f2cfdb3f5db9'
# 建议：可以检查 LOCAL_MODEL_PATH 是否有效，如果无效则设为 None
if not (LOCAL_MODEL_PATH and os.path.exists(LOCAL_MODEL_PATH)):
    log.warning(f"配置的本地模型路径 '{LOCAL_MODEL_PATH}' 无效或不存在。将依赖HuggingFace Hub加载 '{PRETRAINED_MODEL_NAME}'。")
    LOCAL_MODEL_PATH = None # 设为None，让transformers库自动处理

# ==============================================================================
# 4. 训练与数据处理超参数 (Training & Data Processing Hyperparameters)
# ==============================================================================

# --- 4.1 数据处理 (Data Processing) ---
MAX_SEQ_LENGTH = 28  # 文本最大序列长度
NUM_PROC_FOR_MAP = 12 # Dataset.map() 使用的并行进程数

# --- 4.2 DataLoader ---
BATCH_SIZE = 256
#NUM_WORKERS = 6  # DataLoader 的 num_workers。Linux上可尝试增加, Windows或Jupyter中建议为0或较小值
NUM_WORKERS = 0

# --- 4.3 训练过程 (Training Process) ---
# 初次微调的超参数
LEARNING_RATE = 2e-5
NUM_EPOCHS = 6
WARMUP_PROPORTION = 0.1
WEIGHT_DECAY = 0.01
EVAL_STRATEGY = "steps"             # "steps" 或 "epoch"
EVAL_FREQUENCY_FRAC_EPOCH = 0.2     # 如果 EVAL_STRATEGY="steps", 每训练X比例的epoch数据后验证
EARLY_STOPPING_ENABLED = True       # 是否启用早停 (初次微调)
EARLY_STOPPING_PATIENCE = 3        # 验证指标连续多少次评估没有改善后停止 (这里的“次”对应评估的次数)
EARLY_STOPPING_METRIC = 'val_loss'  # 早停基于的指标: 'val_loss', 'val_accuracy', 'val_f1_weighted'
EARLY_STOPPING_MIN_DELTA = 0.001    # 指标改善的最小阈值
# EARLY_STOPPING_MODE 会根据 EARLY_STOPPING_METRIC 自动推断 (含'loss'为'min', 否则'max')

# 二次微调的超参数
AUG_LEARNING_RATE = 2e-5
AUG_NUM_EPOCHS = 6
AUG_WARMUP_PROPORTION = 0.1
AUG_WEIGHT_DECAY = 0.01
# 二次微调的 EVAL_STRATEGY 和 EVAL_FREQUENCY_FRAC_EPOCH 将复用初次微调的设置。
# 如果需要不同，可以定义 AUG_EVAL_STRATEGY 和 AUG_EVAL_FREQUENCY_FRAC_EPOCH。
AUG_EVAL_STRATEGY = "steps"
AUG_EVAL_FREQUENCY_FRAC_EPOCH = 1
# --- 4.4 分析参数 (Analysis Parameters) ---
ANALYSIS_TOP_P_THRESHOLD = 0.9       # 用于单元格19的 top-p 分析的阈值 (如果该单元格仍在使用)
PROB_DIST_PLOT_STEP = 0.001        # 概率分布图的x轴分度值
TOP_PERCENT_FOR_AUGMENTATION = 0.80  # 用于数据增强的最高概率样本比例

log.info("所有配置参数定义完毕。")
log.info(f"文本日志文件: {LOG_FILE_PATH}")
log.info(f"TensorBoard 日志目录: {TENSORBOARD_LOG_DIR}")
log.info(f"原始数据目录: {RAW_DATA_DIR}")
log.info(f"  增强数据文件: {AUG_TRAIN_FILE_PATH}")
log.info(f"缓存数据目录: {CACHE_DATA_DIR}")
log.info(f"模型保存目录: {SAVED_MODELS_DIR}")
log.info(f"  初次最佳模型: {BEST_MODEL_PATH}")
log.info(f"  二次最佳模型: {BEST_AUG_MODEL_PATH}")
log.info(f"最终预测结果目录: {FINAL_RESULTS_DIR}")
log.info(f"  最终预测JSON: {FINAL_TOP1_JSON_PATH}")
log.info(f"  最终预测TXT: {FINAL_RESULT_TXT_PATH}")
log.info(f"初步分析结果目录: {ANALYSIS_DIR}")
log.info(f"  初步分析JSON: {TOP1_JSON_PATH}") # (示例)
log.info(f"最终预测分析图表/CSV目录: {FINAL_ANALYSIS_OUTPUT_PATH}")
log.info(f"  最终分析图: {FINAL_PROB_DIST_PLOT_PATH}") # (示例)

# ==============================================================================
# 单元格 3 结束
# ==============================================================================

2025-05-18 16:52:41,917 - __main__ - INFO - 环境准备与日志配置完成。
2025-05-18 16:52:41,918 - __main__ - INFO - 文本日志将保存到: ./files/logs/training_log_20250518_165241.log
2025-05-18 16:52:41,919 - __main__ - INFO - TensorBoard 日志将保存到: ./files/logs/tensorboard_runs/20250518-165241
2025-05-18 16:52:41,919 - __main__ - INFO - PyTorch 版本: 2.5.1+cu124
2025-05-18 16:52:41,919 - __main__ - INFO - Transformers 版本: 4.51.3
2025-05-18 16:52:41,920 - __main__ - INFO - Datasets 版本: 3.3.2
2025-05-18 16:52:41,920 - __main__ - INFO - Evaluate 版本: 0.4.3
2025-05-18 16:52:41,943 - __main__ - INFO - 将使用设备: cuda
2025-05-18 16:52:41,960 - __main__ - INFO - CUDA 可用. GPU数量: 1
2025-05-18 16:52:41,963 - __main__ - INFO -   GPU 0: NVIDIA GeForce RTX 3090
2025-05-18 16:52:41,963 - __main__ - INFO - 当前默认GPU: NVIDIA GeForce RTX 3090
2025-05-18 16:52:41,965 - __main__ - INFO - 所有配置参数定义完毕。
2025-05-18 16:52:41,966 - __main__ - INFO - 文本日志文件: ./files/logs/training_log_20250518_165241.log
2025-05-18 16:52:41,966 - __main__ - INFO - T

## 数据加载与标签映射 (Data Loading & Label Mapping)

In [2]:
# ==============================================================================

log.info("="*30 + " 开始数据加载与标签映射 " + "="*30)

# --- 5.1 数据加载函数定义 (Data Loading Function Definition) ---
def load_data_from_file(file_path: str, is_test_set: bool = False, delimiter: str = '\t') -> list:
    """
    从指定的文本文件加载数据到字典列表

    参数:
        file_path (str): 数据文件的完整路径
        is_test_set (bool): 指示是否为测试集（测试集没有标签）
        delimiter (str): 文本和标签之间的分隔符（如果适用）

    返回:
        list: 包含数据样本的列表，每个样本是一个字典
              对于训练/验证集: [{'text': str, 'label_text': str}, ...]
              对于测试集: [{'text': str, 'id': int}, ...]
    """
    data_samples = []
    if not os.path.exists(file_path):
        log.error(f"数据文件不存在: {file_path}")
        raise FileNotFoundError(f"Data file not found at {file_path}")

    log.info(f"开始从文件加载数据: {file_path}")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if not line: # 跳过空行
                    continue

                try:
                    if is_test_set:
                        data_samples.append({"text": line, "id": i}) # 为测试集样本添加id
                    else:
                        parts = line.split(delimiter)
                        if len(parts) == 2:
                            text, label_text = parts
                            data_samples.append({"text": text, "label_text": label_text})
                        else:
                            log.warning(f"跳过格式错误的行 {i+1} (0-indexed) in {file_path}: '{line}'. 期望格式: text{delimiter}label")
                except Exception as e_line: # 处理单行解析错误
                    log.warning(f"处理行 {i+1} (0-indexed) 时出错 in {file_path}: '{line}' - 错误: {e_line}")
    except Exception as e_file: # 处理文件读取错误
        log.error(f"读取文件 {file_path} 时发生严重错误: {e_file}")
        raise # 重新抛出异常，让上层处理

    log.info(f"从 {file_path} 成功加载 {len(data_samples)} 条数据")
    return data_samples

# --- 5.2 确保原始数据目录存在 (Ensure Raw Data Directory Exists) ---
# RAW_DATA_DIR 在单元格3中定义
if not os.path.exists(RAW_DATA_DIR):
    os.makedirs(RAW_DATA_DIR)
    log.warning(f"原始数据目录 {RAW_DATA_DIR} 不存在，已自动创建")
    log.warning(f"请确保以下文件已放入该目录中才能继续:")
    log.warning(f"  - {TRAIN_FILE_NAME}")
    log.warning(f"  - {DEV_FILE_NAME}")
    log.warning(f"  - {TEST_FILE_NAME}")
    # 在实际场景中，如果数据是脚本运行的先决条件，这里可能应该直接 exit()
    # exit("错误：原始数据文件缺失，请准备数据后重试")


# --- 5.3 加载所有数据集 (Load All Datasets) ---
# 文件路径 TRAIN_FILE_PATH, DEV_FILE_PATH, TEST_FILE_PATH 在单元格3中定义
try:
    log.info("-" * 20 + " 加载训练集 " + "-" * 20)
    raw_train_data = load_data_from_file(TRAIN_FILE_PATH, is_test_set=False)

    log.info("-" * 20 + " 加载验证集 " + "-" * 20)
    raw_dev_data = load_data_from_file(DEV_FILE_PATH, is_test_set=False)

    log.info("-" * 20 + " 加载测试集 " + "-" * 20)
    raw_test_data = load_data_from_file(TEST_FILE_PATH, is_test_set=True)

except FileNotFoundError as e:
    log.critical(f"关键数据文件加载失败: {e}")
    log.critical("程序无法继续，请检查文件路径和数据是否准备就绪")
    exit(1) # 使用非零退出码表示错误
except Exception as e:
    log.critical(f"加载数据时发生未知严重错误: {e}")
    exit(1)

# --- 5.4 构建类别标签映射 (Build Class Label Mapping) ---
log.info("开始从训练集构建类别标签映射...")
try:
    # 从 raw_train_data 获取标签
    all_train_labels = [item['label_text'] for item in raw_train_data if 'label_text' in item]

    if not raw_train_data:
        log.critical("错误：训练数据集为空，无法构建标签映射")
        exit(1)
    if not all_train_labels:
        log.critical("错误：无法从训练数据中提取到任何标签请检查训练文件格式 (应为 text\\tlabel) 和内容")
        log.critical(f"检查文件: {TRAIN_FILE_PATH}")
        exit(1)

    unique_labels = sorted(list(set(all_train_labels)))
    label_to_id = {label: idx for idx, label in enumerate(unique_labels)}
    id_to_label = {idx: label for label, idx in label_to_id.items()}
    num_classes = len(unique_labels)

    if num_classes == 0:
        log.critical("错误：识别到的类别数量为0，无法继续训练")
        exit(1)
    elif num_classes == 1:
        log.warning("警告：只识别到一个类别模型可能无法有效学习分类任务")


    log.info(f"共找到 {num_classes} 个唯一类别: {unique_labels}")
    log.info(f"类别到ID的映射 (label_to_id): {label_to_id}")
    log.info(f"ID到类别的映射 (id_to_label): {id_to_label}")

except Exception as e:
    log.critical(f"构建标签映射时发生未知错误: {e}")
    exit(1)

log.info("原始数据加载和标签映射完成")
log.info("="*30 + " 数据加载与标签映射结束 " + "="*30 + "\n")

# --- 检查关键变量是否已定义，为后续单元格做准备 ---
# (这些变量应该在此单元格的 try-except 块成功执行后被定义)
assert 'raw_train_data' in locals(), "raw_train_data 未定义"
assert 'raw_dev_data' in locals(), "raw_dev_data 未定义"
assert 'raw_test_data' in locals(), "raw_test_data 未定义"
assert 'label_to_id' in locals(), "label_to_id 未定义"
assert 'id_to_label' in locals(), "id_to_label 未定义"
assert 'num_classes' in locals(), "num_classes 未定义"

2025-05-18 16:52:45,034 - __main__ - INFO - -------------------- 加载训练集 --------------------
2025-05-18 16:52:45,035 - __main__ - INFO - 开始从文件加载数据: ./files/raw_data/train.txt
2025-05-18 16:52:45,692 - __main__ - INFO - 从 ./files/raw_data/train.txt 成功加载 752471 条数据
2025-05-18 16:52:45,692 - __main__ - INFO - -------------------- 加载验证集 --------------------
2025-05-18 16:52:45,693 - __main__ - INFO - 开始从文件加载数据: ./files/raw_data/dev.txt
2025-05-18 16:52:45,762 - __main__ - INFO - 从 ./files/raw_data/dev.txt 成功加载 80000 条数据
2025-05-18 16:52:45,763 - __main__ - INFO - -------------------- 加载测试集 --------------------
2025-05-18 16:52:45,764 - __main__ - INFO - 开始从文件加载数据: ./files/raw_data/test.txt
2025-05-18 16:52:45,814 - __main__ - INFO - 从 ./files/raw_data/test.txt 成功加载 83599 条数据
2025-05-18 16:52:45,815 - __main__ - INFO - 开始从训练集构建类别标签映射...
2025-05-18 16:52:45,885 - __main__ - INFO - 共找到 14 个唯一类别: ['体育', '娱乐', '家居', '彩票', '房产', '教育', '时尚', '时政', '星座', '游戏', '社会', '科技', '股票', '财经']
2025-05-18 16:

## 数据预处理 (Tokenization & Dataset Creation)

In [3]:
# ==============================================================================
# 主要功能:
# 1. 加载预训练模型的Tokenizer
# 2. 定义文本预处理函数，将文本转换为模型可接受的输入格式 (input_ids, attention_mask)，并将标签文本转换为标签ID
# 3. 将原始数据列表转换为 Hugging Face Dataset 对象
# 4. 使用 Dataset.map() 方法对所有数据进行批处理和并行预处理
# 5. 实现预处理后数据的磁盘缓存与加载机制，避免重复处理
# 6. 将处理后的 Dataset 对象设置为 PyTorch Tensor 格式，以便后续 DataLoader 使用
# 7. (可选) 打印一个处理后的样本进行检查
# ==============================================================================
log.info("="*30 + " 开始数据预处理 " + "="*30)

# --- 6.1 加载 Tokenizer (Load Tokenizer) ---

tokenizer = None
log.info(f"尝试加载 Tokenizer...")
if LOCAL_MODEL_PATH and os.path.exists(LOCAL_MODEL_PATH):
    try:
        log.info(f"从本地路径 '{LOCAL_MODEL_PATH}' 加载 Tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
        log.info(f"Tokenizer 从本地路径 '{LOCAL_MODEL_PATH}' 加载成功")
    except Exception as e_local:
        log.warning(f"从本地路径 '{LOCAL_MODEL_PATH}' 加载 Tokenizer 失败: {e_local}")
        log.info(f"将尝试从 Hugging Face Hub 加载 Tokenizer: '{PRETRAINED_MODEL_NAME}'")
        tokenizer = None # 重置以确保尝试从网络加载

if tokenizer is None: # 如果本地加载失败或未提供本地路径
    try:
        log.info(f"从 Hugging Face Hub ({PRETRAINED_MODEL_NAME}) 加载 Tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
        log.info(f"Tokenizer '{PRETRAINED_MODEL_NAME}' 从 Hugging Face Hub 加载成功")
    except Exception as e_hub:
        log.critical(f"从 Hugging Face Hub 加载 Tokenizer '{PRETRAINED_MODEL_NAME}' 失败: {e_hub}")
        log.critical("无法加载Tokenizer，程序将终止")
        exit(1)

# --- 6.2 定义预处理函数 (Define Preprocessing Function) ---
def preprocess_function(
    examples: dict, # Dataset.map 会传入一个批次的样本，其结构为 {'text': [str], 'label_text': [str]} 或 {'text': [str], 'id': [int]}
    tokenizer_instance: AutoTokenizer,
    max_len: int,
    label_map: dict = None,
    is_test_set: bool = False
) -> dict:
    """
    对一批文本样本进行tokenize和编码
    此函数将被 Dataset.map 调用

    参数:
        examples (dict): 包含文本和其他字段的字典 (Dataset.map传入)
        tokenizer_instance (AutoTokenizer): 已初始化的Tokenizer
        max_len (int): Tokenize后的最大序列长度
        label_map (dict, optional): 标签文本到ID的映射对于非测试集是必需的
        is_test_set (bool): 是否为测试集，如果是则没有label

    返回:
        dict: 包含 'input_ids', 'attention_mask' 以及 (如果适用) 'labels' 或 'id' 的字典
    """
    # 对文本进行批处理编码
    # examples['text'] 是一个文本列表
    tokenized_batch = tokenizer_instance(
        examples['text'],
        max_length=max_len,
        padding='max_length', # 填充到max_len
        truncation=True,      # 超出max_len则截断
        return_attention_mask=True,
        return_tensors="pt"   # 返回PyTorch张量，后续Dataset.set_format会处理
    )

    processed_batch = {
        'input_ids': tokenized_batch['input_ids'],
        'attention_mask': tokenized_batch['attention_mask']
    }

    if not is_test_set: # 训练集/验证集情况
        if label_map is None:
            raise ValueError("对于训练/验证集，label_map 不能为空")
        if 'label_text' not in examples:
            log.warning("警告：输入样本批次中缺少 'label_text' 键，但当前不是测试集模式")
            if not examples.get('label_text'): # 检查是否为空列表或 None
                 raise ValueError("训练/验证集样本缺少 'label_text' 或内容为空")

        # 将标签文本转换为ID
        try:
            label_ids = [label_map[label] for label in examples['label_text']]
            processed_batch['labels'] = torch.tensor(label_ids, dtype=torch.long)
        except KeyError as e_key:
            log.error(f"标签映射错误：标签 '{e_key}' 不在 label_map 中请检查训练数据和标签映射的构建")
            raise
        except TypeError as e_type: # 例如 examples['label_text'] 不是可迭代对象或 label_map 不是字典
            log.error(f"处理标签时发生类型错误: {e_type}. examples['label_text']: {examples.get('label_text')}, label_map: {type(label_map)}")
            raise
    else: # 测试集情况
        if 'id' in examples: # Dataset.map会传递原始列
            processed_batch['id'] = examples['id']
        else: log.debug("测试集样本中缺少 'id' 字段")

    return processed_batch

# --- 6.3 定义数据集创建与预处理函数 (Define Dataset Creation & Preprocessing Function) ---
def create_and_process_dataset(
    raw_data_list: list,
    tokenizer_instance: AutoTokenizer,
    max_len: int,
    dataset_display_name: str,
    is_test_set: bool = False,
    label_map: dict = None, # 仅在 is_test_set=False 时需要
    num_processing_workers: int = 4 # 并行处理进程数
) -> Dataset:
    """
    将原始数据列表转换为 Hugging Face Dataset 对象，并对其进行预处理

    参数:
        raw_data_list (list): 包含原始样本字典的列表
        tokenizer_instance (AutoTokenizer): 已初始化的Tokenizer
        max_len (int): Tokenize后的最大序列长度
        dataset_display_name (str): 用于日志和进度的据集名称 (例如 "训练集")
        is_test_set (bool): 是否为测试集
        label_map (dict, optional): 标签文本到ID的映射
        num_processing_workers (int): Dataset.map 使用的进程数

    返回:
        Dataset: 处理后的 Hugging Face Dataset 对象
    """
    if not raw_data_list:
        log.warning(f"原始数据列表 '{dataset_display_name}' 为空，将返回一个空 Dataset")
        return Dataset.from_list([]) # 返回空数据集

    log.info(f"开始为 '{dataset_display_name}' 创建 HuggingFace Dataset 对象...")
    hf_dataset = Dataset.from_list(raw_data_list)
    log.info(f"'{dataset_display_name}' Dataset 对象创建成功，包含 {len(hf_dataset)} 条记录")

    bound_preprocess_function = partial(
        preprocess_function,
        tokenizer_instance=tokenizer_instance,
        max_len=max_len,
        label_map=label_map,
        is_test_set=is_test_set
    )

    log.info(f"开始对 '{dataset_display_name}' 进行并行预处理 (使用 {num_processing_workers} 个进程)...")
    # 使用map方法进行并行预处理
    columns_to_remove = ['text']
    if not is_test_set:
        columns_to_remove.append('label_text')
    columns_to_remove_actual = [col for col in columns_to_remove if col in hf_dataset.column_names]

    processed_dataset = hf_dataset.map(
        bound_preprocess_function,
        batched=True,          # 启用批处理
        batch_size=1000,       
        num_proc=num_processing_workers, # 并行处理的进程数
        remove_columns=columns_to_remove_actual, # 移除原始列
        desc=f"预处理 {dataset_display_name}"
    )
    log.info(f"'{dataset_display_name}' 预处理完成")
    return processed_dataset

# --- 6.4 实例化并缓存/加载 HuggingFace Datasets (Instantiate & Cache/Load Datasets) ---

log.info("开始实例化或加载预处理后的 HuggingFace Dataset 对象...")
try:
    # 检查所有缓存文件是否都存在
    all_cache_exists = (
        os.path.exists(CACHE_TRAIN_PATH) and
        os.path.exists(CACHE_DEV_PATH) and
        os.path.exists(CACHE_TEST_PATH)
    )

    if all_cache_exists:
        log.info("所有预处理好的数据集缓存均存在，正在从磁盘加载...")
        train_dataset = Dataset.load_from_disk(CACHE_TRAIN_PATH)
        dev_dataset = Dataset.load_from_disk(CACHE_DEV_PATH)
        test_dataset = Dataset.load_from_disk(CACHE_TEST_PATH)
        log.info("从磁盘加载预处理数据集成功")
    else:
        log.info("预处理数据集缓存不完整或不存在，将重新处理所有数据集...")

        # 创建并预处理训练集
        train_dataset = create_and_process_dataset(
            raw_train_data, tokenizer, MAX_SEQ_LENGTH,
            dataset_display_name="训练集", is_test_set=False, label_map=label_to_id,
            num_processing_workers=NUM_PROC_FOR_MAP
        )
        log.info(f"保存预处理后的训练集到: {CACHE_TRAIN_PATH}")
        train_dataset.save_to_disk(CACHE_TRAIN_PATH)

        # 创建并预处理验证集
        dev_dataset = create_and_process_dataset(
            raw_dev_data, tokenizer, MAX_SEQ_LENGTH,
            dataset_display_name="验证集", is_test_set=False, label_map=label_to_id,
            num_processing_workers=NUM_PROC_FOR_MAP
        )
        log.info(f"保存预处理后的验证集到: {CACHE_DEV_PATH}")
        dev_dataset.save_to_disk(CACHE_DEV_PATH)

        # 创建并预处理测试集
        test_dataset = create_and_process_dataset(
            raw_test_data, tokenizer, MAX_SEQ_LENGTH,
            dataset_display_name="测试集", is_test_set=True, label_map=None, # 测试集不需要label_map
            num_processing_workers=NUM_PROC_FOR_MAP
        )
        log.info(f"保存预处理后的测试集到: {CACHE_TEST_PATH}")
        test_dataset.save_to_disk(CACHE_TEST_PATH)

        log.info("所有数据集预处理完成并已保存到磁盘缓存")

    log.info("HuggingFace Dataset 对象实例化/加载成功")

    # --- 6.5 设置Dataset格式为PyTorch Tensors (Set Dataset Format to PyTorch Tensors) ---
    log.info("设置数据集格式为 PyTorch Tensors...")
    
    # 训练集和验证集需要 'input_ids', 'attention_mask', 'labels'
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    dev_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    
    # 测试集需要 'input_ids', 'attention_mask', 以及原始的 'id'
    test_columns_for_torch = ['input_ids', 'attention_mask']
    
    if 'id' in test_dataset.column_names: # 检查 'id' 列是否存在
        test_columns_for_torch.append('id')
    else:
        log.warning("测试数据集中未找到 'id' 列，将不包含在PyTorch格式中如果需要原始ID进行后续关联，请检查预处理逻辑")
    test_dataset.set_format(type='torch', columns=test_columns_for_torch)
    log.info("数据集格式设置完成")

except NameError as e_name:
    log.critical(f"创建 Dataset 时发生 NameError (可能是变量未定义): {e_name}")
    log.critical("请确保以下变量已在之前的单元格中正确加载和定义: "
                 "'raw_train_data', 'raw_dev_data', 'raw_test_data', "
                 "'tokenizer', 'label_to_id', 'MAX_SEQ_LENGTH', "
                 "'CACHE_TRAIN_PATH', 'CACHE_DEV_PATH', 'CACHE_TEST_PATH', 'NUM_PROC_FOR_MAP'.")
    exit(1)
except Exception as e_general:
    log.critical(f"创建或处理 HuggingFace Dataset 对象时发生未知严重错误: {e_general}", exc_info=True) 
    exit(1)


# --- 6.6 (可选) 检查处理后的数据 (Optional: Inspect Processed Data) ---
log.info("-" * 20 + " 检查处理后的训练样本 " + "-" * 20)
if len(train_dataset) > 0:
    sample = train_dataset[0]
    log.info(f"  样本类型: {type(sample)}")
    log.info(f"  样本键: {list(sample.keys())}")
    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            log.info(f"    '{key}': shape={value.shape}, dtype={value.dtype}")
        else:
            log.info(f"    '{key}': type={type(value)}, value={value}") # 对于非Tensor值也打印出来

    # 尝试解码 Input IDs 进行验证
    try:
        decoded_text_preview = tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
        log.info(f"  解码后的 Input IDs (预览): '{decoded_text_preview}'")
        if raw_train_data and isinstance(raw_train_data[0], dict) and 'text' in raw_train_data[0]:
             log.info(f"  对应的原始文本 (预览): '{raw_train_data[0]['text']}'")
    except Exception as e_decode:
        log.warning(f"解码样本Input IDs时出错: {e_decode}")
else:
    log.warning("处理后的训练集为空，无法打印样本进行检查")

log.info("="*30 + " 数据预处理结束 " + "="*30 + "\n")

# --- 确保关键变量已定义 ---
assert 'tokenizer' in locals(), "tokenizer 未定义"
assert 'train_dataset' in locals(), "train_dataset 未定义"
assert 'dev_dataset' in locals(), "dev_dataset 未定义"
assert 'test_dataset' in locals(), "test_dataset 未定义"

2025-05-18 16:52:47,401 - __main__ - INFO - 尝试加载 Tokenizer...
2025-05-18 16:52:47,401 - __main__ - INFO - 从本地路径 '/root/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext-large/snapshots/a25cc9e05974bd9687e528edd516f2cfdb3f5db9' 加载 Tokenizer...
2025-05-18 16:52:47,427 - __main__ - INFO - Tokenizer 从本地路径 '/root/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext-large/snapshots/a25cc9e05974bd9687e528edd516f2cfdb3f5db9' 加载成功
2025-05-18 16:52:47,428 - __main__ - INFO - 开始实例化或加载预处理后的 HuggingFace Dataset 对象...
2025-05-18 16:52:47,429 - __main__ - INFO - 预处理数据集缓存不完整或不存在，将重新处理所有数据集...
2025-05-18 16:52:47,429 - __main__ - INFO - 开始为 '训练集' 创建 HuggingFace Dataset 对象...
2025-05-18 16:52:48,176 - __main__ - INFO - '训练集' Dataset 对象创建成功，包含 752471 条记录
2025-05-18 16:52:48,177 - __main__ - INFO - 开始对 '训练集' 进行并行预处理 (使用 12 个进程)...


预处理 训练集 (num_proc=12):   0%|          | 0/752471 [00:00<?, ? examples/s]

2025-05-18 16:52:59,689 - __main__ - INFO - '训练集' 预处理完成
2025-05-18 16:52:59,694 - __main__ - INFO - 保存预处理后的训练集到: ./files/processed_data_cache/train_dataset_cache


Saving the dataset (0/1 shards):   0%|          | 0/752471 [00:00<?, ? examples/s]

2025-05-18 16:53:00,030 - __main__ - INFO - 开始为 '验证集' 创建 HuggingFace Dataset 对象...
2025-05-18 16:53:00,083 - __main__ - INFO - '验证集' Dataset 对象创建成功，包含 80000 条记录
2025-05-18 16:53:00,083 - __main__ - INFO - 开始对 '验证集' 进行并行预处理 (使用 12 个进程)...


预处理 验证集 (num_proc=12):   0%|          | 0/80000 [00:00<?, ? examples/s]

2025-05-18 16:53:02,532 - __main__ - INFO - '验证集' 预处理完成
2025-05-18 16:53:02,533 - __main__ - INFO - 保存预处理后的验证集到: ./files/processed_data_cache/dev_dataset_cache


Saving the dataset (0/1 shards):   0%|          | 0/80000 [00:00<?, ? examples/s]

2025-05-18 16:53:02,587 - __main__ - INFO - 开始为 '测试集' 创建 HuggingFace Dataset 对象...
2025-05-18 16:53:02,649 - __main__ - INFO - '测试集' Dataset 对象创建成功，包含 83599 条记录
2025-05-18 16:53:02,650 - __main__ - INFO - 开始对 '测试集' 进行并行预处理 (使用 12 个进程)...


预处理 测试集 (num_proc=12):   0%|          | 0/83599 [00:00<?, ? examples/s]

2025-05-18 16:53:05,059 - __main__ - INFO - '测试集' 预处理完成
2025-05-18 16:53:05,060 - __main__ - INFO - 保存预处理后的测试集到: ./files/processed_data_cache/test_dataset_cache


Saving the dataset (0/1 shards):   0%|          | 0/83599 [00:00<?, ? examples/s]

2025-05-18 16:53:05,108 - __main__ - INFO - 所有数据集预处理完成并已保存到磁盘缓存
2025-05-18 16:53:05,108 - __main__ - INFO - HuggingFace Dataset 对象实例化/加载成功
2025-05-18 16:53:05,109 - __main__ - INFO - 设置数据集格式为 PyTorch Tensors...
2025-05-18 16:53:05,110 - __main__ - INFO - 数据集格式设置完成
2025-05-18 16:53:05,110 - __main__ - INFO - -------------------- 检查处理后的训练样本 --------------------
2025-05-18 16:53:05,113 - __main__ - INFO -   样本类型: <class 'dict'>
2025-05-18 16:53:05,113 - __main__ - INFO -   样本键: ['input_ids', 'attention_mask', 'labels']
2025-05-18 16:53:05,114 - __main__ - INFO -     'input_ids': shape=torch.Size([28]), dtype=torch.int64
2025-05-18 16:53:05,114 - __main__ - INFO -     'attention_mask': shape=torch.Size([28]), dtype=torch.int64
2025-05-18 16:53:05,114 - __main__ - INFO -     'labels': shape=torch.Size([]), dtype=torch.int64
2025-05-18 16:53:05,115 - __main__ - INFO -   解码后的 Input IDs (预览): '网 易 第 三 季 度 业 绩 低 于 分 析 师 预 期'
2025-05-18 16:53:05,115 - __main__ - INFO -   对应的原始文本 (预览): '网易第三季度业绩低

## 模型加载 (Model Loading)

In [4]:
# ==============================================================================
# 主要功能:
# 1. 使用 AutoModelForSequenceClassification.from_pretrained() 加载预训练的序列分类模型
#    - 关键参数: num_labels=num_classes，用于初始化与任务类别数匹配的分类头
# 2. 将加载的模型移动到指定的计算设备 (CPU/GPU)
# 3. 对模型加载过程中可能发生的错误 (如路径错误、网络问题、配置不匹配) 进行捕获和日志记录
# ==============================================================================
log.info("="*30 + " 开始模型加载 " + "="*30)

# --- 7.1 检查前置条件 (Check Prerequisites) ---
required_vars_for_model_loading = [
    'num_classes',        # 模型分类头的类别数
    'device',             # 模型将加载到的设备
    'LOCAL_MODEL_PATH',   # 本地模型路径配置
    'PRETRAINED_MODEL_NAME' # 预训练模型名称配置
]
log.info("检查模型加载所需的前置变量...")
for var_name in required_vars_for_model_loading:
    if var_name not in locals() and var_name not in globals(): # 检查局部和全局作用域
        log.critical(f"模型加载失败：关键前置变量 '{var_name}' 未定义，请检查之前的单元格")
        exit(1) 
log.info("所有模型加载所需的前置变量均已定义")

# --- 7.2 模型加载 (Load the Model) ---
log.info("开始加载序列分类模型...")

model = None
model_successfully_loaded_from_path = ""

# 优先尝试从本地路径加载
if LOCAL_MODEL_PATH and os.path.exists(LOCAL_MODEL_PATH): # 确保本地路径非空且存在
    log.info(f"尝试从本地路径 '{LOCAL_MODEL_PATH}' 加载模型 (类别数: {num_classes})...")
    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            LOCAL_MODEL_PATH,
            num_labels=num_classes,
            ignore_mismatched_sizes=False # 如果模型头的尺寸不匹配num_labels，报错
        )
        model_successfully_loaded_from_path = LOCAL_MODEL_PATH
        log.info(f"模型成功从本地路径 '{LOCAL_MODEL_PATH}' 加载")
    except OSError as e_os_local: 
        log.warning(f"从本地路径 '{LOCAL_MODEL_PATH}' 加载模型时发生 OSError: {e_os_local}")
        log.info(f"将尝试从 Hugging Face Hub 使用 '{PRETRAINED_MODEL_NAME}' 加载模型")
        model = None # 重置 model，确保后续尝试从 Hub 加载
    except ValueError as e_val_local: # 通常是 num_labels 与模型头不匹配
        log.warning(f"从本地路径 '{LOCAL_MODEL_PATH}' 加载模型时发生 ValueError: {e_val_local}")
        log.warning("这可能意味着本地模型是一个已微调过的模型，其分类头与当前任务的 num_classes 不匹配")
        log.info(f"将尝试从 Hugging Face Hub 使用 '{PRETRAINED_MODEL_NAME}' 重新初始化模型头并加载")
        model = None
    except Exception as e_local_other: # 其他本地加载错误
        log.warning(f"从本地路径 '{LOCAL_MODEL_PATH}' 加载模型时发生未知错误: {e_local_other}")
        log.info(f"将尝试从 Hugging Face Hub 使用 '{PRETRAINED_MODEL_NAME}' 加载模型")
        model = None

if model is None:
    log.info(f"尝试从 Hugging Face Hub 使用 '{PRETRAINED_MODEL_NAME}' 加载模型 (类别数: {num_classes})...")
    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            PRETRAINED_MODEL_NAME,
            num_labels=num_classes,
            ignore_mismatched_sizes=False 
        )
        model_successfully_loaded_from_path = PRETRAINED_MODEL_NAME
        log.info(f"模型成功从 Hugging Face Hub ('{PRETRAINED_MODEL_NAME}') 加载")
    except OSError as e_os_hub:
        log.critical(f"从 Hugging Face Hub ('{PRETRAINED_MODEL_NAME}') 加载模型时发生 OSError: {e_os_hub}")
        log.critical("可能由网络问题、模型名称错误等导致,程序无法继续")
        exit(1)
    except ValueError as e_val_hub:
        log.critical(f"从 Hugging Face Hub ('{PRETRAINED_MODEL_NAME}') 加载模型时发生 ValueError: {e_val_hub}")
        log.critical(f"可能原因：num_labels ({num_classes}) 与 '{PRETRAINED_MODEL_NAME}' 的预训练分类头（如果存在）不匹配或者模型配置问题，程序无法继续")
        exit(1)
    except Exception as e_hub_other:
        log.critical(f"从 Hugging Face Hub ('{PRETRAINED_MODEL_NAME}') 加载模型时发生未知严重错误: {e_hub_other}", exc_info=True)
        exit(1)

# --- 7.3 将模型移至设备并记录信息 (Move Model to Device & Log Info) ---
try:
    model.to(device)
    log.info(f"模型已成功移至设备: {device}.")
except Exception as e_device:
    log.critical(f"将模型移至设备 {device} 时发生错误: {e_device}", exc_info=True)
    exit(1)

# (可选) 打印模型结构摘要信息
num_model_params = sum(p.numel() for p in model.parameters())
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
log.info(f"模型总参数量: {num_model_params:,}")
log.info(f"模型可训练参数量: {num_trainable_params:,}")


log.info(f"模型 '{model_successfully_loaded_from_path}' (基于 '{PRETRAINED_MODEL_NAME}' 结构) 加载并配置完成")
log.info("="*30 + " 模型加载结束 " + "="*30 + "\n")

# --- 确保关键变量已定义 ---
assert 'model' in locals() and model is not None, "模型 (model) 未成功加载或未定义"

2025-05-18 16:53:05,127 - __main__ - INFO - 检查模型加载所需的前置变量...
2025-05-18 16:53:05,128 - __main__ - INFO - 所有模型加载所需的前置变量均已定义
2025-05-18 16:53:05,128 - __main__ - INFO - 开始加载序列分类模型...
2025-05-18 16:53:05,129 - __main__ - INFO - 尝试从本地路径 '/root/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext-large/snapshots/a25cc9e05974bd9687e528edd516f2cfdb3f5db9' 加载模型 (类别数: 14)...
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /root/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext-large/snapshots/a25cc9e05974bd9687e528edd516f2cfdb3f5db9 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-05-18 16:53:07,455 - __main__ - INFO - 模型成功从本地路径 '/root/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext-large/snapshots/a25cc9e05974bd9687e528edd516f2cfdb3f5db9' 加载
2025-05-18 16:53:07,787 - __main__ - INFO - 

## DataLoader 配置 (DataLoader Configuration)

In [5]:
# ==============================================================================
# 主要功能:
# 1. 检查前置步骤定义的关键变量和配置参数
# 2. 创建 DataCollatorWithPadding: 用于在每个批次内部动态填充序列，
#    使其长度一致，以便模型处理
# 3. 实例化 PyTorch DataLoader 对象:
#    - 为训练集 (train_data_loader): 启用数据打乱 (shuffle=True)
#    - 为验证集 (dev_data_loader): 不打乱数据 (shuffle=False)
#    - (可选) 为测试集 (test_data_loader): 不打乱数据 (shuffle=False)
#    - 所有 DataLoader 都使用配置的 batch_size, num_workers, 和 collate_fn
# 4. (可选) 从训练 DataLoader 中获取一个样本批次进行检查，验证其结构和内容
# ==============================================================================
log.info("="*30 + " 开始配置 DataLoaders " + "="*30)

# --- 8.1 检查前置条件 (Check Prerequisites) ---
required_vars_for_dataloader = [
    'tokenizer', 'train_dataset', 'dev_dataset', # test_dataset 是可选的
    'BATCH_SIZE', 'NUM_WORKERS'
]
log.info("检查 DataLoader 配置所需的前置变量...")
for var_name in required_vars_for_dataloader:
    if var_name not in locals() and var_name not in globals():
        # 特别处理可选的 test_dataset
        if var_name == 'test_dataset':
            log.info(f"可选的前置变量 '{var_name}' 未定义，将跳过测试集 DataLoader 的创建")
            test_dataset = None # 确保它被定义为None，以便后续逻辑正确处理
            continue
        log.critical(f"DataLoader 配置失败：关键前置变量 '{var_name}' 未定义，请检查之前的单元格")
        exit(1)
        
if 'test_dataset' not in locals() and 'test_dataset' not in globals():
    test_dataset = None # 确保 test_dataset 有定义，即使是 None

log.info("所有 DataLoader 配置所需的核心前置变量均已定义")
log.info(f"DataLoader 配置参数: Batch Size={BATCH_SIZE}, Num Workers={NUM_WORKERS}")

# --- 8.2 定义数据整理器 (Define Data Collator) ---
# DataCollatorWithPadding 会接收一个批次的数据（字典列表）
# 并使用传入的 tokenizer 来确定如何填充 (padding) input_ids, attention_mask等,使它们在批次内具有相同的长度
# 即使之前的Dataset.map已经做了padding='max_length'到固定长度,
# DataCollatorWithPadding 仍然确保批次正确组合
log.info("创建 DataCollatorWithPadding...")
try:
    collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        padding='longest' # 动态填充到批次内最长序列的长度
    )
    log.info("DataCollatorWithPadding 创建成功")
except Exception as e:
    log.critical(f"创建 DataCollatorWithPadding 失败: {e}", exc_info=True)
    exit(1)

# --- 8.3 实例化训练集 DataLoader (Instantiate Training DataLoader) ---
# `pin_memory=True` 可以加速CPU到GPU的数据传输
# 如果内存不足或在某些特定环境（如CPU-only训练或某些虚拟机），可能需要设为 False
pin_memory_enabled = torch.cuda.is_available() # 仅当CUDA可用时考虑启用
log.info(f"为训练集 DataLoader 设置 pin_memory={pin_memory_enabled}")

try:
    log.info(f"创建训练集 DataLoader (Batch Size: {BATCH_SIZE}, Shuffle: True, Num Workers: {NUM_WORKERS}, Pin Memory: {pin_memory_enabled})")
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collator,
        num_workers=NUM_WORKERS,
        pin_memory=pin_memory_enabled
    )
    log.info("训练集 DataLoader 创建成功")
except Exception as e:
    log.critical(f"创建训练集 DataLoader 失败: {e}", exc_info=True)
    exit(1)

# --- 8.4 实例化验证集 DataLoader (Instantiate Validation DataLoader) ---
try:
    log.info(f"创建验证集 DataLoader (Batch Size: {BATCH_SIZE}, Shuffle: False, Num Workers: {NUM_WORKERS}, Pin Memory: {pin_memory_enabled})")
    dev_data_loader = DataLoader(
        dataset=dev_dataset,
        batch_size=BATCH_SIZE, # 通常验证/测试时可以使用更大的batch size（如果内存允许）
        shuffle=False,
        collate_fn=collator,
        num_workers=NUM_WORKERS,
        pin_memory=pin_memory_enabled
    )
    log.info("验证集 DataLoader 创建成功")
except Exception as e:
    log.critical(f"创建验证集 DataLoader 失败: {e}", exc_info=True)
    exit(1)

# --- 8.5 (可选) 实例化测试集 DataLoader (Optional: Instantiate Test DataLoader) ---
test_data_loader = None # 初始化为 None
if test_dataset is not None and len(test_dataset) > 0:
    try:
        log.info(f"创建测试集 DataLoader (Batch Size: {BATCH_SIZE}, Shuffle: False, Num Workers: {NUM_WORKERS}, Pin Memory: {pin_memory_enabled})")
        test_data_loader = DataLoader(
            dataset=test_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            collate_fn=collator,
            num_workers=NUM_WORKERS,
            pin_memory=pin_memory_enabled
        )
        log.info("测试集 DataLoader 创建成功")
    except Exception as e:
        log.error(f"创建测试集 DataLoader 失败: {e}", exc_info=True) # 使用error而不是critical，因为测试集可能是可选的
        log.warning("后续测试步骤可能无法执行")
        test_data_loader = None # 显式设为 None
else:
    if test_dataset is None:
        log.info("未找到 'test_dataset' 或其为 None，跳过创建测试集 DataLoader")
    elif len(test_dataset) == 0:
        log.warning("'test_dataset' 为空，跳过创建测试集 DataLoader")
    test_data_loader = None

log.info("所有 DataLoader 配置完成")

# --- 8.6 (可选) 检查 DataLoader 输出 (Optional: Inspect DataLoader Output) ---
if train_data_loader is not None:
    log.info("-" * 20 + " 检查训练 DataLoader 输出 " + "-" * 20)
    try:
        sample_batch = next(iter(train_data_loader))
        log.info(f"成功从 train_data_loader 获取一个批次")
        log.info(f"批次包含的键: {list(sample_batch.keys())}")
        for key, value in sample_batch.items():
            if isinstance(value, torch.Tensor):
                log.info(f"  - '{key}': shape={value.shape}, dtype={value.dtype}, device={value.device}")
            else: # 例如 'id' 列表（如果存在且未转为Tensor）
                log.info(f"  - '{key}': type={type(value)}, 示例值 (前5个): {value[:5] if isinstance(value, list) else value}")
        # 预期输出的关键张量:
        # - 'input_ids': shape=[BATCH_SIZE, sequence_length], dtype=torch.int64
        # - 'attention_mask': shape=[BATCH_SIZE, sequence_length], dtype=torch.int64
        # - 'labels' (for train/dev): shape=[BATCH_SIZE], dtype=torch.int64
        # - 'id' (for test, if kept): 可能是列表或Tensor，取决于 collate_fn 和 Dataset 格式
    except StopIteration:
        log.warning("无法从 train_data_loader 获取批次，可能是训练数据集为空")
    except Exception as e:
        log.error(f"尝试从 train_data_loader 获取或检查批次时出错: {e}", exc_info=True)
        if train_dataset is not None and len(train_dataset) > 0:
            log.info(f"打印一个原始训练样本 (HuggingFace Dataset 格式，应用set_format后): {train_dataset[0]}")
else:
    log.warning("train_data_loader 未成功创建或为 None，跳过批次检查")

log.info("DataLoader 配置已准备就绪，后续步骤可以定义优化器、学习率调度器等")
log.info("="*30 + " DataLoaders 配置结束 " + "="*30 + "\n")

# --- 确保关键变量已定义 ---
assert 'collator' in locals(), "collator 未定义"
assert 'train_data_loader' in locals() and train_data_loader is not None, "train_data_loader 未成功创建或未定义"
assert 'dev_data_loader' in locals() and dev_data_loader is not None, "dev_data_loader 未成功创建或未定义"

2025-05-18 16:53:07,812 - __main__ - INFO - 检查 DataLoader 配置所需的前置变量...
2025-05-18 16:53:07,812 - __main__ - INFO - 所有 DataLoader 配置所需的核心前置变量均已定义
2025-05-18 16:53:07,813 - __main__ - INFO - DataLoader 配置参数: Batch Size=256, Num Workers=6
2025-05-18 16:53:07,813 - __main__ - INFO - 创建 DataCollatorWithPadding...
2025-05-18 16:53:07,814 - __main__ - INFO - DataCollatorWithPadding 创建成功
2025-05-18 16:53:07,814 - __main__ - INFO - 为训练集 DataLoader 设置 pin_memory=True
2025-05-18 16:53:07,815 - __main__ - INFO - 创建训练集 DataLoader (Batch Size: 256, Shuffle: True, Num Workers: 6, Pin Memory: True)
2025-05-18 16:53:07,815 - __main__ - INFO - 训练集 DataLoader 创建成功
2025-05-18 16:53:07,816 - __main__ - INFO - 创建验证集 DataLoader (Batch Size: 256, Shuffle: False, Num Workers: 6, Pin Memory: True)
2025-05-18 16:53:07,816 - __main__ - INFO - 验证集 DataLoader 创建成功
2025-05-18 16:53:07,817 - __main__ - INFO - 创建测试集 DataLoader (Batch Size: 256, Shuffle: False, Num Workers: 6, Pin Memory: True)
2025-05-18 16:53:07,817 

## 优化器、学习率调度器与评估指标配置(Optimizer, Learning Rate Scheduler & Metrics Configuration)

In [6]:
# ==============================================================================
# 主要功能:
# 1. 检查前置步骤定义的关键变量和配置参数
# 2. 定义训练超参数 (Learning Rate, Epochs, Warmup Proportion, Weight Decay)
# 3. 创建优化器 (AdamW)
# 4. 根据训练步数和预热比例计算预热步数，并创建学习率调度器
#    (get_linear_schedule_with_warmup)
# 5. 加载评估指标 (Accuracy, F1-Score) 使用 `evaluate` 库
# 6. 将配置的超参数记录到 TensorBoard
# ==============================================================================
log.info("="*30 + " 开始配置优化器、学习率调度器与评估指标 " + "="*30)

# --- 9.1 检查前置条件 (Check Prerequisites) ---
required_vars_for_optimizer_setup = [
    'model', 'train_data_loader',
    'LEARNING_RATE', 'NUM_EPOCHS', 'WARMUP_PROPORTION', 'WEIGHT_DECAY'
]
log.info("检查优化器等配置所需的前置变量...")
for var_name in required_vars_for_optimizer_setup:
    if var_name not in locals() and var_name not in globals():
        log.critical(f"配置失败：关键前置变量 '{var_name}' 未定义，请检查之前的单元格")
        exit(1)
log.info("所有优化器等配置所需的前置变量均已定义")

log.info(f"训练超参数: Learning Rate={LEARNING_RATE}, Epochs={NUM_EPOCHS}, "
         f"Warmup Proportion={WARMUP_PROPORTION}, Weight Decay={WEIGHT_DECAY}")

# --- 9.2 定义优化器 (Define Optimizer) ---
log.info("定义优化器 (AdamW)...")
try:
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), # 只优化需要梯度的参数
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        eps=1e-8 # AdamW的epsilon参数，增加数值稳定性，Hugging Face Trainer常用此值
    )
    log.info(f"AdamW 优化器创建成功Learning Rate: {LEARNING_RATE}, Weight Decay: {WEIGHT_DECAY}, Eps: 1e-8")
except Exception as e:
    log.critical(f"创建 AdamW 优化器失败: {e}", exc_info=True)
    exit(1)

# --- 9.3 定义学习率调度器 (Define Learning Rate Scheduler) ---
log.info("定义学习率调度器 (get_linear_schedule_with_warmup)...")

# 检查训练数据加载器和Epoch数
if train_data_loader is None or len(train_data_loader) == 0:
    log.critical("训练数据加载器 (train_data_loader) 为空或未定义，无法计算总训练步数")
    exit(1)
if NUM_EPOCHS <= 0:
    log.critical(f"训练轮数 (NUM_EPOCHS) 必须为正数，当前为: {NUM_EPOCHS}")
    exit(1)
if not (0.0 <= WARMUP_PROPORTION <= 1.0):
    log.critical(f"学习率预热比例 (WARMUP_PROPORTION) 必须在 [0.0, 1.0] 之间，当前为: {WARMUP_PROPORTION}")
    exit(1)

num_training_steps = NUM_EPOCHS * len(train_data_loader)
num_warmup_steps = int(WARMUP_PROPORTION * num_training_steps)

if num_warmup_steps > num_training_steps:
    log.warning(f"计算得到的预热步数 ({num_warmup_steps}) 大于总训练步数 ({num_training_steps})，将预热步数调整为总训练步数")
    num_warmup_steps = num_training_steps
if num_training_steps == 0 :
    log.critical("总训练步数为0，无法创建学习率调度器")
    exit(1)

log.info(f"总训练步数 (for LR Scheduler): {num_training_steps}")
log.info(f"预热步数 (Warmup Steps): {num_warmup_steps}")

try:
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )
    log.info("学习率调度器创建成功")
except Exception as e:
    log.critical(f"创建学习率调度器失败: {e}", exc_info=True)
    exit(1)

# --- 9.4 定义评估指标 (Define Evaluation Metrics) ---
log.info("定义评估指标 (Accuracy 和 F1-Score 使用 evaluate.load)...")
accuracy_metric = None
f1_metric = None
try:
    # 尝试加载 accuracy 指标
    try:
        accuracy_metric = evaluate.load("accuracy")
        log.info("Accuracy 指标加载成功")
    except ModuleNotFoundError:
        log.error("Hugging Face 'evaluate' 库或 'accuracy' 指标模块未找到，请确保已正确安装: pip install evaluate scikit-learn") 
        log.warning("将跳过 Accuracy 指标的计算")
    except Exception as e_acc:
        log.error(f"加载 Accuracy 指标失败: {e_acc}")
        log.warning("将跳过 Accuracy 指标的计算")

    # 尝试加载 f1 指标
    try:
        f1_metric = evaluate.load("f1")
        log.info("F1-Score 指标加载成功")
    except ModuleNotFoundError:
        log.error("Hugging Face 'evaluate' 库或 'f1' 指标模块未找到，请确保已正确安装: pip install evaluate scikit-learn") 
        log.warning("将跳过 F1-Score 指标的计算")
    except Exception as e_f1:
        log.error(f"加载 F1-Score 指标失败: {e_f1}")
        log.warning("将跳过 F1-Score 指标的计算")

    if accuracy_metric is None and f1_metric is None:
        log.warning("所有评估指标均加载失败，训练仍可进行，但不会计算验证集指标")
    else:
        log.info("评估指标加载过程完成")

except Exception as e_load_main: # 捕获 evaluate.load 外层可能发生的其他问题
    log.error(f"加载评估指标过程中发生意外错误: {e_load_main}", exc_info=True)
    accuracy_metric = None
    f1_metric = None
    log.warning("所有评估指标均设置为 None")


# --- 9.5 将超参数记录到 TensorBoard ---
# 此处记录的指标只是占位符，实际指标值在训练循环中计算并记录
# 你也可以在训练开始前就这样记录一次配置
try:
    hparams = {
        'learning_rate': LEARNING_RATE,
        'num_epochs': NUM_EPOCHS,
        'batch_size': BATCH_SIZE, 
        'warmup_proportion': WARMUP_PROPORTION,
        'weight_decay': WEIGHT_DECAY,
        'max_seq_length': MAX_SEQ_LENGTH, 
        'model_name': PRETRAINED_MODEL_NAME 
    }
    # Tensorboard 的 add_hparams 通常期望有一个指标字典，即使是初始值
    # 这里的 metric_dict 是为了满足 API 要求，实际的指标值会在训练中更新
    metric_dict_placeholder = {
        'hparam/accuracy': 0,
        'hparam/f1': 0
    }
    writer.add_hparams(hparams, metric_dict_placeholder)
    log.info(f"超参数已记录到 TensorBoard: {TENSORBOARD_LOG_DIR}")
except Exception as e_tb:
    log.warning(f"记录超参数到 TensorBoard 时发生错误: {e_tb}")


log.info("优化器、学习率调度器、评估指标配置完成")
log.info("="*30 + " 优化器等配置结束 " + "="*30 + "\n")

# --- 确保关键变量已定义 ---
assert 'optimizer' in locals() and optimizer is not None, "优化器 (optimizer) 未成功创建或未定义"
assert 'lr_scheduler' in locals() and lr_scheduler is not None, "学习率调度器 (lr_scheduler) 未成功创建或未定义"

2025-05-18 16:53:08,691 - __main__ - INFO - 检查优化器等配置所需的前置变量...
2025-05-18 16:53:08,691 - __main__ - INFO - 所有优化器等配置所需的前置变量均已定义
2025-05-18 16:53:08,692 - __main__ - INFO - 训练超参数: Learning Rate=2e-05, Epochs=6, Warmup Proportion=0.1, Weight Decay=0.01
2025-05-18 16:53:08,692 - __main__ - INFO - 定义优化器 (AdamW)...
2025-05-18 16:53:08,694 - __main__ - INFO - AdamW 优化器创建成功Learning Rate: 2e-05, Weight Decay: 0.01, Eps: 1e-8
2025-05-18 16:53:08,695 - __main__ - INFO - 定义学习率调度器 (get_linear_schedule_with_warmup)...
2025-05-18 16:53:08,695 - __main__ - INFO - 总训练步数 (for LR Scheduler): 17640
2025-05-18 16:53:08,696 - __main__ - INFO - 预热步数 (Warmup Steps): 1764
2025-05-18 16:53:08,696 - __main__ - INFO - 学习率调度器创建成功
2025-05-18 16:53:08,696 - __main__ - INFO - 定义评估指标 (Accuracy 和 F1-Score 使用 evaluate.load)...
2025-05-18 16:53:08,706 - __main__ - INFO - Accuracy 指标加载成功
2025-05-18 16:53:08,716 - __main__ - INFO - F1-Score 指标加载成功
2025-05-18 16:53:08,717 - __main__ - INFO - 评估指标加载过程完成
2025-05-18 16:53:08,7

## 模型训练与验证循环 (Model Training & Validation Loop)

### eval函数

In [7]:
# ==============================================================================
# 主要功能:
# 1. 实现完整的训练周期 (多 Epochs)
# 2. 在每个 Epoch 内:
#    a. 训练阶段 (Training):
#       - 设置模型为训练模式 (model.train())。
#       - 迭代训练数据加载器 (train_data_loader)。
#       - 执行前向传播、计算损失、反向传播、更新优化器和学习率。
#       - 记录训练损失到日志和 TensorBoard。
#    b. 定期验证阶段 (Intra-Epoch/End-of-Epoch Evaluation):
#       - 根据配置的 EVAL_STRATEGY 和 EVAL_FREQUENCY_FRAC_EPOCH，
#         在训练过程中或每个 Epoch 结束后执行验证。
#       - 设置模型为评估模式 (model.eval())。
#       - 迭代验证数据加载器 (dev_data_loader)。
#       - 计算验证集上的损失、准确率、F1分数等指标。
#       - 记录验证指标到日志和 TensorBoard。
# 3. 模型保存:
#    - 根据验证集上的主要评估指标 (如准确率) 保存表现最佳的模型权重。
# 4. 早停机制 (Early Stopping):
#    - 监控验证集上的指定指标 (EARLY_STOPPING_METRIC)。
#    - 如果指标在连续 EARLY_STOPPING_PATIENCE 次评估中没有明显改善
#      (超过 EARLY_STOPPING_MIN_DELTA)，则提前终止训练。
# 5. 记录整体训练耗时和最终的最佳指标。
# ==============================================================================
log.info("="*30 + " 开始模型训练与验证 " + "="*30)
log.debug(f"evaluate_model: Global accuracy_metric type: {type(accuracy_metric)}, value: {accuracy_metric is not None}")
log.debug(f"evaluate_model: Global f1_metric type: {type(f1_metric)}, value: {f1_metric is not None}")
# --- 10.1 检查前置条件 (Check Prerequisites) ---
required_vars_for_training = [
    'model', 'optimizer', 'lr_scheduler',
    'train_data_loader', 'dev_data_loader',
    'NUM_EPOCHS', 'device', 'log', 'writer', 
    'BEST_MODEL_PATH',
    'EVAL_STRATEGY', 
    'EARLY_STOPPING_PATIENCE','EARLY_STOPPING_METRIC','EARLY_STOPPING_MIN_DELTA'
]
if EVAL_STRATEGY == "steps":
    required_vars_for_training.append('EVAL_FREQUENCY_FRAC_EPOCH')

log.info("检查训练与验证所需的前置变量...")
for var_name in required_vars_for_training:
    if var_name not in locals() and var_name not in globals():
        log.critical(f"训练失败：关键前置变量 '{var_name}' 未定义，请检查之前的单元格。")
        exit(1)
log.info("所有训练与验证所需的前置变量均已定义。")

# --- 10.2 初始化训练状态变量 (Initialize Training State Variables) ---
best_val_metric_for_saving = -float('inf') # 用于保存模型的最佳指标值 (假设是准确率，越大越好)
                                           # 如果是loss，则应为 float('inf')
current_best_model_info = {"epoch": 0, "step":0, "val_accuracy": 0.0, "val_f1": 0.0, "val_loss": float('inf')}


global_step = 0  # 用于 TensorBoard 的全局步数计数
training_completed_normally = False # 标记训练是否正常完成所有epoch

# 早停机制变量初始化
if EARLY_STOPPING_ENABLED:
    log.info(f"早停机制已启用: Patience={EARLY_STOPPING_PATIENCE}, Metric='{EARLY_STOPPING_METRIC}', Min Delta={EARLY_STOPPING_MIN_DELTA}")
    epochs_no_improve = 0 # 连续多少次评估没有改善
    # 根据监控的指标确定初始最佳值和比较模式
    if EARLY_STOPPING_METRIC.endswith('loss'):
        best_early_stopping_metric_val = float('inf')
        early_stopping_mode = 'min'
    else: # accuracy, f1
        best_early_stopping_metric_val = -float('inf')
        early_stopping_mode = 'max'
    # 如果在配置中显式设置了 EARLY_STOPPING_MODE，则覆盖自动推断
    if 'EARLY_STOPPING_MODE' in locals() and EARLY_STOPPING_MODE in ['min', 'max']:
        early_stopping_mode = EARLY_STOPPING_MODE
        if early_stopping_mode == 'min':
            best_early_stopping_metric_val = float('inf')
        else:
            best_early_stopping_metric_val = -float('inf')
    log.info(f"早停监控模式: {early_stopping_mode} (指标: {EARLY_STOPPING_METRIC})")


# 计算按步数验证的频率
eval_every_n_steps = 0
if EVAL_STRATEGY == "steps":
    if not (0 < EVAL_FREQUENCY_FRAC_EPOCH <= 1.0):
        log.warning(f"EVAL_FREQUENCY_FRAC_EPOCH ({EVAL_FREQUENCY_FRAC_EPOCH}) 不在 (0, 1] 范围内，将调整为每个epoch结束时验证。")
        EVAL_STRATEGY = "epoch" # 回退到按epoch验证
    else:
        eval_every_n_steps = max(1, int(len(train_data_loader) * EVAL_FREQUENCY_FRAC_EPOCH))
        log.info(f"按步数验证策略已启用：每训练 {eval_every_n_steps} 步 (约占epoch的 {EVAL_FREQUENCY_FRAC_EPOCH*100:.1f}%) 进行一次验证。")
else:
    log.info("按Epoch结束验证策略已启用。")


# --- 10.3 验证函数定义 (Define Evaluation Function) ---

def evaluate_model(current_model, dataloader, current_device, eval_step_info=""):
    log.info(f"开始验证... {eval_step_info}")
    current_model.eval() # 设置模型为评估模式

    total_eval_loss = 0

    log.debug(f"evaluate_model: Global accuracy_metric type: {type(accuracy_metric)}, value: {accuracy_metric is not None}")
    log.debug(f"evaluate_model: Global f1_metric type: {type(f1_metric)}, value: {f1_metric is not None}")


    can_evaluate_accuracy = accuracy_metric is not None # 直接使用全局变量
    can_evaluate_f1 = f1_metric is not None             # 直接使用全局变量

    if not can_evaluate_accuracy and not can_evaluate_f1:
        log.warning("所有评估指标均未初始化或加载失败，无法计算详细验证指标。仅计算验证损失。")
        # 注意：这里的警告信息可以更精确，因为它现在反映的是全局变量的状态

    val_progress_bar = tqdm(dataloader, desc=f"Validation {eval_step_info}", leave=False, position=0)
    with torch.no_grad(): # 无需计算梯度
        for batch in val_progress_bar:
            input_ids = batch['input_ids'].to(current_device)
            attention_mask = batch['attention_mask'].to(current_device)
            labels = batch['labels'].to(current_device)

            outputs = current_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            total_eval_loss += loss.item()
            predictions = torch.argmax(logits, dim=-1)

            if can_evaluate_accuracy:
                accuracy_metric.add_batch(predictions=predictions.cpu(), references=labels.cpu()) # 使用全局 accuracy_metric
            if can_evaluate_f1:
                f1_metric.add_batch(predictions=predictions.cpu(), references=labels.cpu())       # 使用全局 f1_metric

    val_progress_bar.close()
    avg_val_loss = total_eval_loss / len(dataloader) if len(dataloader) > 0 else float('inf')
    log.info(f"验证完成. 平均验证损失: {avg_val_loss:.4f}")

    val_accuracy = -1.0
    val_f1_weighted = -1.0

    if can_evaluate_accuracy:
        try:
            acc_result = accuracy_metric.compute() # 使用全局 accuracy_metric
            val_accuracy = acc_result['accuracy']
            log.info(f"  验证准确率 (Accuracy): {val_accuracy:.4f}")
        except Exception as e:
            log.error(f"计算验证 Accuracy 时出错: {e}")
    else:
        log.debug("Accuracy 指标未加载或不可用，跳过计算。") # 更新日志信息

    if can_evaluate_f1:
        try:
            f1_result = f1_metric.compute(average="weighted") # 使用全局 f1_metric
            val_f1_weighted = f1_result['f1']
            log.info(f"  验证加权 F1-Score: {val_f1_weighted:.4f}")
        except Exception as e:
            log.error(f"计算验证 F1-Score 时出错: {e}")
    else:
        log.debug("F1-Score 指标未加载或不可用，跳过计算。") # 更新日志信息

    return avg_val_loss, val_accuracy, val_f1_weighted

2025-05-18 16:53:51,500 - __main__ - INFO - 检查训练与验证所需的前置变量...
2025-05-18 16:53:51,501 - __main__ - INFO - 所有训练与验证所需的前置变量均已定义。
2025-05-18 16:53:51,502 - __main__ - INFO - 早停机制已启用: Patience=10, Metric='val_loss', Min Delta=0.001
2025-05-18 16:53:51,502 - __main__ - INFO - 早停监控模式: min (指标: val_loss)
2025-05-18 16:53:51,503 - __main__ - INFO - 按步数验证策略已启用：每训练 588 步 (约占epoch的 20.0%) 进行一次验证。


### 训练主循环

In [8]:
# --- 10.4 训练主循环 (Main Training Loop) ---
global_start_time = time.time()
try:
    for epoch in range(1, NUM_EPOCHS + 1):
        epoch_start_time = time.time()
        log.info(f"--- Epoch {epoch}/{NUM_EPOCHS} ---")
        model.train()
        total_train_loss_epoch = 0
        train_progress_bar = tqdm(train_data_loader, desc=f"Epoch {epoch} Training", leave=True, position=0)

        for step, batch in enumerate(train_progress_bar):
            global_step += 1 # 更新全局步数

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # (可选) 梯度裁剪
            optimizer.step()
            lr_scheduler.step() # 每个step更新学习率

            total_train_loss_epoch += loss.item()
            current_lr = optimizer.param_groups[0]["lr"]

            # 更新tqdm的描述信息
            train_progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss_epoch': f'{total_train_loss_epoch / (step + 1):.4f}',
                'lr': f'{current_lr:.2e}'
            })

            # 记录训练损失到TensorBoard (可以按需调整频率，例如每N步)
            if global_step % (max(1, len(train_data_loader) // 100)) == 0: # 大约记录100次/epoch
                 writer.add_scalar('Loss/train_step', loss.item(), global_step)
                 writer.add_scalar('LearningRate/step', current_lr, global_step)


            # --- 定期验证 (Intra-Epoch Evaluation) ---
            if EVAL_STRATEGY == "steps" and (step + 1) % eval_every_n_steps == 0:
                step_info = f"Epoch {epoch}, Step {step+1}/{len(train_data_loader)}"
                val_loss, val_accuracy, val_f1_weighted = evaluate_model(model, dev_data_loader, device, step_info)
                model.train() # 确保验证后模型回到训练模式

                # 记录验证指标到 TensorBoard
                writer.add_scalar('Loss/validation_step', val_loss, global_step)
                if val_accuracy != -1.0: writer.add_scalar('Accuracy/validation_step', val_accuracy, global_step)
                if val_f1_weighted != -1.0: writer.add_scalar('F1_weighted/validation_step', val_f1_weighted, global_step)

                # 检查是否保存模型 (基于val_accuracy)
                if val_accuracy > best_val_metric_for_saving:
                    best_val_metric_for_saving = val_accuracy
                    current_best_model_info = {"epoch": epoch, "step":step+1, "val_accuracy": val_accuracy, "val_f1": val_f1_weighted, "val_loss": val_loss}
                    log.info(f"🎉 新的最佳模型 (基于Accuracy)! Accuracy: {val_accuracy:.4f} at {step_info}. 保存模型到 {BEST_MODEL_PATH}")
                    torch.save(model.state_dict(), BEST_MODEL_PATH)
                else:
                    log.info(f"当前验证 Accuracy {val_accuracy:.4f} 未超过历史最佳 {best_val_metric_for_saving:.4f} (at {step_info}).")


                # 早停机制检查
                if EARLY_STOPPING_ENABLED:
                    current_metric_for_early_stop = -1.0
                    if EARLY_STOPPING_METRIC == 'val_accuracy': current_metric_for_early_stop = val_accuracy
                    elif EARLY_STOPPING_METRIC == 'val_f1_weighted': current_metric_for_early_stop = val_f1_weighted
                    elif EARLY_STOPPING_METRIC == 'val_loss': current_metric_for_early_stop = val_loss
                    else: # 默认用val_accuracy
                        log.warning(f"未知的早停指标 '{EARLY_STOPPING_METRIC}', 将使用 'val_accuracy'.")
                        current_metric_for_early_stop = val_accuracy

                    improved = False
                    if early_stopping_mode == 'max':
                        if current_metric_for_early_stop > best_early_stopping_metric_val + EARLY_STOPPING_MIN_DELTA:
                            improved = True
                    else: # min mode
                        if current_metric_for_early_stop < best_early_stopping_metric_val - EARLY_STOPPING_MIN_DELTA:
                            improved = True
                    
                    if improved:
                        log.info(f"早停指标改善: {EARLY_STOPPING_METRIC} 从 {best_early_stopping_metric_val:.4f} -> {current_metric_for_early_stop:.4f}. 重置Patience计数器。")
                        best_early_stopping_metric_val = current_metric_for_early_stop
                        epochs_no_improve = 0
                    else:
                        epochs_no_improve += 1
                        log.info(f"早停指标未改善 ({EARLY_STOPPING_METRIC}: {current_metric_for_early_stop:.4f} vs best: {best_early_stopping_metric_val:.4f}). Patience: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}")
                    
                    if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
                        log.info(f"🛑 早停触发! {EARLY_STOPPING_METRIC} 已连续 {epochs_no_improve} 次评估未改善，训练终止。")
                        training_completed_normally = False # 标记为非正常结束
                        raise StopIteration("Early stopping triggered") # 使用 StopIteration 来跳出所有循环

        train_progress_bar.close() # 关闭当前epoch的训练进度条
        avg_epoch_train_loss = total_train_loss_epoch / len(train_data_loader) if len(train_data_loader) > 0 else float('inf')
        log.info(f"Epoch {epoch} 训练阶段完成. 平均训练损失: {avg_epoch_train_loss:.4f}")
        writer.add_scalar('Loss/train_epoch', avg_epoch_train_loss, epoch) # TensorBoard记录epoch平均训练损失

        # --- Epoch结束时验证 (End-of-Epoch Evaluation) ---
        # 如果是按epoch验证，或者按steps验证但当前epoch的最后一步不是验证步，则执行
        perform_epoch_end_eval = (EVAL_STRATEGY == "epoch") or \
                                 (EVAL_STRATEGY == "steps" and len(train_data_loader) % eval_every_n_steps != 0)

        if perform_epoch_end_eval:
            epoch_end_info = f"Epoch {epoch} End"
            val_loss, val_accuracy, val_f1_weighted = evaluate_model(model, dev_data_loader, device, epoch_end_info)
            # 记录验证指标到 TensorBoard
            writer.add_scalar('Loss/validation_epoch', val_loss, epoch)
            if val_accuracy != -1.0: writer.add_scalar('Accuracy/validation_epoch', val_accuracy, epoch)
            if val_f1_weighted != -1.0: writer.add_scalar('F1_weighted/validation_epoch', val_f1_weighted, epoch)

            # 检查是否保存模型 (基于val_accuracy)
            if val_accuracy > best_val_metric_for_saving:
                best_val_metric_for_saving = val_accuracy
                current_best_model_info = {"epoch": epoch, "step":"end_of_epoch", "val_accuracy": val_accuracy, "val_f1": val_f1_weighted, "val_loss": val_loss}
                log.info(f"🎉 新的最佳模型 (基于Accuracy)! Accuracy: {val_accuracy:.4f} at {epoch_end_info}. 保存模型到 {BEST_MODEL_PATH}")
                torch.save(model.state_dict(), BEST_MODEL_PATH)
            else:
                log.info(f"当前验证 Accuracy {val_accuracy:.4f} 未超过历史最佳 {best_val_metric_for_saving:.4f} (at {epoch_end_info}).")

            # 早停机制检查 (如果按epoch验证)
            if EARLY_STOPPING_ENABLED and EVAL_STRATEGY == "epoch": # 只在按epoch验证时，在epoch结束检查早停
                current_metric_for_early_stop = -1.0
                if EARLY_STOPPING_METRIC == 'val_accuracy': current_metric_for_early_stop = val_accuracy
                elif EARLY_STOPPING_METRIC == 'val_f1_weighted': current_metric_for_early_stop = val_f1_weighted
                elif EARLY_STOPPING_METRIC == 'val_loss': current_metric_for_early_stop = val_loss
                else:
                    log.warning(f"未知的早停指标 '{EARLY_STOPPING_METRIC}', 将使用 'val_accuracy'.")
                    current_metric_for_early_stop = val_accuracy

                improved = False
                if early_stopping_mode == 'max':
                    if current_metric_for_early_stop > best_early_stopping_metric_val + EARLY_STOPPING_MIN_DELTA:
                        improved = True
                else: # min mode
                    if current_metric_for_early_stop < best_early_stopping_metric_val - EARLY_STOPPING_MIN_DELTA:
                        improved = True

                if improved:
                    log.info(f"早停指标改善: {EARLY_STOPPING_METRIC} 从 {best_early_stopping_metric_val:.4f} -> {current_metric_for_early_stop:.4f}. 重置Patience计数器。")
                    best_early_stopping_metric_val = current_metric_for_early_stop
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    log.info(f"早停指标未改善 ({EARLY_STOPPING_METRIC}: {current_metric_for_early_stop:.4f} vs best: {best_early_stopping_metric_val:.4f}). Patience: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}")

                if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
                    log.info(f"🛑 早停触发! {EARLY_STOPPING_METRIC} 已连续 {epochs_no_improve} 次评估未改善，训练终止。")
                    training_completed_normally = False
                    raise StopIteration("Early stopping triggered")

        epoch_duration = time.time() - epoch_start_time
        log.info(f"Epoch {epoch} 总耗时: {epoch_duration:.2f} 秒")
        writer.add_scalar('Time/epoch_duration_seconds', epoch_duration, epoch)

    training_completed_normally = True # 如果循环正常结束，标记为true

except StopIteration as e_stop: # 捕获早停信号
    if str(e_stop) == "Early stopping triggered":
        log.info("训练因早停机制提前结束。")
    else:
        log.error(f"训练意外因 StopIteration 终止: {e_stop}", exc_info=True) # 其他 StopIteration
except KeyboardInterrupt:
    log.warning("训练被用户手动中断 (KeyboardInterrupt)。")
    training_completed_normally = False
except Exception as e_train:
    log.critical(f"训练过程中发生严重错误: {e_train}", exc_info=True)
    training_completed_normally = False
finally:
    # --- 10.5 训练结束总结 (Training Completion Summary) ---
    if training_completed_normally:
        log.info("所有计划的 Epoch 训练与验证完成。")
    else:
        log.info("训练未完成所有计划的 Epoch。")

    log.info(f"训练期间记录的最佳模型信息: Epoch={current_best_model_info['epoch']}, Step={current_best_model_info['step']}, "
             f"Val Accuracy={current_best_model_info['val_accuracy']:.4f}, "
             f"Val F1={current_best_model_info['val_f1']:.4f}, "
             f"Val Loss={current_best_model_info['val_loss']:.4f}")
    if os.path.exists(BEST_MODEL_PATH) and best_val_metric_for_saving > -float('inf'):
        log.info(f"最佳模型参数已保存至: {BEST_MODEL_PATH} (对应验证准确率: {best_val_metric_for_saving:.4f})")
    else:
        log.warning(f"未成功保存任何最佳模型到 {BEST_MODEL_PATH} (可能是因为验证指标从未改善或未进行有效评估)。")

    total_training_time = time.time() - global_start_time
    hours, rem = divmod(total_training_time, 3600)
    minutes, seconds = divmod(rem, 60)
    log.info(f"总训练耗时: {int(hours):02d}小时 {int(minutes):02d}分钟 {seconds:.2f}秒")
    writer.close() # 关闭 TensorBoard writer
    log.info(f"TensorBoard 日志已保存到: {TENSORBOARD_LOG_DIR}")

log.info("="*30 + " 模型训练与验证结束 " + "="*30 + "\n")

# ... 第一次微调训练循环结束 ...

log.info("第一次微调完成，准备进行内存清理...")

if 'optimizer' in locals():
    del optimizer
    log.info("已删除第一次微调的优化器 (optimizer)。")
if 'lr_scheduler' in locals():
    del lr_scheduler
    log.info("已删除第一次微调的学习率调度器 (lr_scheduler)。")

# 2. 删除不再需要的原始训练数据和加载器
if 'train_data_loader' in locals(): 
    del train_data_loader
    log.info("已删除原始训练数据加载器 (train_data_loader)。")
if 'train_dataset' in locals(): 
    del train_dataset
    log.info("已删除原始训练数据集对象 (train_dataset)。")

2025-05-18 16:53:56,385 - __main__ - INFO - --- Epoch 1/6 ---


Epoch 1 Training:   0%|          | 0/2940 [00:00<?, ?it/s]

2025-05-18 17:01:24,592 - __main__ - INFO - 开始验证... Epoch 1, Step 588/2940


Validation Epoch 1, Step 588/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 17:02:51,956 - __main__ - INFO - 验证完成. 平均验证损失: 0.2121
2025-05-18 17:02:52,044 - __main__ - INFO -   验证准确率 (Accuracy): 0.9372
2025-05-18 17:02:52,151 - __main__ - INFO -   验证加权 F1-Score: 0.9371
2025-05-18 17:02:52,153 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9372 at Epoch 1, Step 588/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 17:03:23,418 - __main__ - INFO - 早停指标改善: val_loss 从 inf -> 0.2121. 重置Patience计数器。
2025-05-18 17:10:50,672 - __main__ - INFO - 开始验证... Epoch 1, Step 1176/2940


Validation Epoch 1, Step 1176/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 17:12:18,069 - __main__ - INFO - 验证完成. 平均验证损失: 0.1575
2025-05-18 17:12:18,167 - __main__ - INFO -   验证准确率 (Accuracy): 0.9511
2025-05-18 17:12:18,282 - __main__ - INFO -   验证加权 F1-Score: 0.9512
2025-05-18 17:12:18,285 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9511 at Epoch 1, Step 1176/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 17:12:19,630 - __main__ - INFO - 早停指标改善: val_loss 从 0.2121 -> 0.1575. 重置Patience计数器。
2025-05-18 17:19:47,743 - __main__ - INFO - 开始验证... Epoch 1, Step 1764/2940


Validation Epoch 1, Step 1764/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 17:29:43,644 - __main__ - INFO - 开始验证... Epoch 1, Step 2352/2940


Validation Epoch 1, Step 2352/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 17:31:10,460 - __main__ - INFO - 验证完成. 平均验证损失: 0.1220
2025-05-18 17:31:10,547 - __main__ - INFO -   验证准确率 (Accuracy): 0.9607
2025-05-18 17:31:10,656 - __main__ - INFO -   验证加权 F1-Score: 0.9607
2025-05-18 17:31:10,659 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9607 at Epoch 1, Step 2352/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 17:31:12,590 - __main__ - INFO - 早停指标改善: val_loss 从 0.1400 -> 0.1220. 重置Patience计数器。
2025-05-18 17:39:08,717 - __main__ - INFO - 开始验证... Epoch 1, Step 2940/2940


Validation Epoch 1, Step 2940/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 17:40:36,115 - __main__ - INFO - 验证完成. 平均验证损失: 0.1011
2025-05-18 17:40:36,263 - __main__ - INFO -   验证准确率 (Accuracy): 0.9673
2025-05-18 17:40:36,446 - __main__ - INFO -   验证加权 F1-Score: 0.9673
2025-05-18 17:40:36,449 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9673 at Epoch 1, Step 2940/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 17:40:38,248 - __main__ - INFO - 早停指标改善: val_loss 从 0.1220 -> 0.1011. 重置Patience计数器。
2025-05-18 17:40:38,388 - __main__ - INFO - Epoch 1 训练阶段完成. 平均训练损失: 0.3092
2025-05-18 17:40:38,389 - __main__ - INFO - Epoch 1 总耗时: 2802.00 秒
2025-05-18 17:40:38,389 - __main__ - INFO - --- Epoch 2/6 ---


Epoch 2 Training:   0%|          | 0/2940 [00:00<?, ?it/s]

2025-05-18 17:48:08,298 - __main__ - INFO - 开始验证... Epoch 2, Step 588/2940


Validation Epoch 2, Step 588/2940:   0%|          | 0/313 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torc

Validation Epoch 2, Step 1176/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 17:59:01,039 - __main__ - INFO - 验证完成. 平均验证损失: 0.0819
2025-05-18 17:59:01,126 - __main__ - INFO -   验证准确率 (Accuracy): 0.9735
2025-05-18 17:59:01,238 - __main__ - INFO -   验证加权 F1-Score: 0.9734
2025-05-18 17:59:01,241 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9735 at Epoch 2, Step 1176/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 17:59:02,905 - __main__ - INFO - 早停指标改善: val_loss 从 0.0918 -> 0.0819. 重置Patience计数器。
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_p

Validation Epoch 2, Step 2940/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 18:26:46,369 - __main__ - INFO - 验证完成. 平均验证损失: 0.0617
2025-05-18 18:26:46,453 - __main__ - INFO -   验证准确率 (Accuracy): 0.9798
2025-05-18 18:26:46,578 - __main__ - INFO -   验证加权 F1-Score: 0.9799
2025-05-18 18:26:46,581 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9798 at Epoch 2, Step 2940/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 18:27:17,883 - __main__ - INFO - 早停指标改善: val_loss 从 0.0685 -> 0.0617. 重置Patience计数器。
2025-05-18 18:27:18,013 - __main__ - INFO - Epoch 2 训练阶段完成. 平均训练损失: 0.1079
2025-05-18 18:27:18,014 - __main__ - INFO - Epoch 2 总耗时: 2799.63 秒
2025-05-18 18:27:18,015 - __main__ - INFO - --- Epoch 3/6 ---


Epoch 3 Training:   0%|          | 0/2940 [00:00<?, ?it/s]

2025-05-18 18:35:45,146 - __main__ - INFO - 开始验证... Epoch 3, Step 588/2940


Validation Epoch 3, Step 588/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 18:37:11,327 - __main__ - INFO - 验证完成. 平均验证损失: 0.0581
2025-05-18 18:37:11,421 - __main__ - INFO -   验证准确率 (Accuracy): 0.9804
2025-05-18 18:37:11,541 - __main__ - INFO -   验证加权 F1-Score: 0.9804
2025-05-18 18:37:11,544 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9804 at Epoch 3, Step 588/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 18:37:13,227 - __main__ - INFO - 早停指标改善: val_loss 从 0.0617 -> 0.0581. 重置Patience计数器。
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pi

Validation Epoch 3, Step 1176/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 18:46:34,552 - __main__ - INFO - 验证完成. 平均验证损失: 0.0497
2025-05-18 18:46:34,640 - __main__ - INFO -   验证准确率 (Accuracy): 0.9835
2025-05-18 18:46:34,750 - __main__ - INFO -   验证加权 F1-Score: 0.9835
2025-05-18 18:46:34,753 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9835 at Epoch 3, Step 1176/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 18:46:36,629 - __main__ - INFO - 早停指标改善: val_loss 从 0.0581 -> 0.0497. 重置Patience计数器。
2025-05-18 18:54:02,407 - __main__ - INFO - 开始验证... Epoch 3, Step 1764/2940


Validation Epoch 3, Step 1764/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 18:55:29,599 - __main__ - INFO - 验证完成. 平均验证损失: 0.0443
2025-05-18 18:55:29,686 - __main__ - INFO -   验证准确率 (Accuracy): 0.9853
2025-05-18 18:55:29,796 - __main__ - INFO -   验证加权 F1-Score: 0.9853
2025-05-18 18:55:29,799 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9853 at Epoch 3, Step 1764/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 18:55:31,447 - __main__ - INFO - 早停指标改善: val_loss 从 0.0497 -> 0.0443. 重置Patience计数器。
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_p

Validation Epoch 3, Step 2352/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 19:05:22,090 - __main__ - INFO - 验证完成. 平均验证损失: 0.0418
2025-05-18 19:05:22,177 - __main__ - INFO -   验证准确率 (Accuracy): 0.9861
2025-05-18 19:05:22,287 - __main__ - INFO -   验证加权 F1-Score: 0.9861
2025-05-18 19:05:22,290 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9861 at Epoch 3, Step 2352/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 19:05:53,903 - __main__ - INFO - 早停指标改善: val_loss 从 0.0443 -> 0.0418. 重置Patience计数器。
2025-05-18 19:13:18,902 - __main__ - INFO - 开始验证... Epoch 3, Step 2940/2940


Validation Epoch 3, Step 2940/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 19:14:45,421 - __main__ - INFO - 验证完成. 平均验证损失: 0.0371
2025-05-18 19:14:45,507 - __main__ - INFO -   验证准确率 (Accuracy): 0.9879
2025-05-18 19:14:45,616 - __main__ - INFO -   验证加权 F1-Score: 0.9879
2025-05-18 19:14:45,618 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9879 at Epoch 3, Step 2940/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 19:14:47,255 - __main__ - INFO - 早停指标改善: val_loss 从 0.0418 -> 0.0371. 重置Patience计数器。
2025-05-18 19:14:47,385 - __main__ - INFO - Epoch 3 训练阶段完成. 平均训练损失: 0.0681
2025-05-18 19:14:47,386 - __main__ - INFO - Epoch 3 总耗时: 2849.37 秒
2025-05-18 19:14:47,387 - __main__ - INFO - --- Epoch 4/6 ---


Epoch 4 Training:   0%|          | 0/2940 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>





Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/root/miniconda3/lib/pyt

Validation Epoch 4, Step 588/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 19:24:11,439 - __main__ - INFO - 验证完成. 平均验证损失: 0.0292
2025-05-18 19:24:11,524 - __main__ - INFO -   验证准确率 (Accuracy): 0.9906
2025-05-18 19:24:11,631 - __main__ - INFO -   验证加权 F1-Score: 0.9906
2025-05-18 19:24:11,634 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9906 at Epoch 4, Step 588/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 19:24:12,943 - __main__ - INFO - 早停指标改善: val_loss 从 0.0371 -> 0.0292. 重置Patience计数器。
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pi

Validation Epoch 4, Step 1764/2940:   0%|          | 0/313 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

2025-05-18 19:59:51,221 - __main__ - INFO - 开始验证... Ep

Validation Epoch 4, Step 2940/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 20:01:18,428 - __main__ - INFO - 验证完成. 平均验证损失: 0.0181
2025-05-18 20:01:18,516 - __main__ - INFO -   验证准确率 (Accuracy): 0.9949
2025-05-18 20:01:18,623 - __main__ - INFO -   验证加权 F1-Score: 0.9949
2025-05-18 20:01:18,626 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9949 at Epoch 4, Step 2940/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 20:01:20,257 - __main__ - INFO - 早停指标改善: val_loss 从 0.0222 -> 0.0181. 重置Patience计数器。
2025-05-18 20:01:20,404 - __main__ - INFO - Epoch 4 训练阶段完成. 平均训练损失: 0.0415
2025-05-18 20:01:20,406 - __main__ - INFO - Epoch 4 总耗时: 2793.02 秒
2025-05-18 20:01:20,406 - __main__ - INFO - --- Epoch 5/6 ---


Epoch 5 Training:   0%|          | 0/2940 [00:00<?, ?it/s]

2025-05-18 20:09:47,510 - __main__ - INFO - 开始验证... Epoch 5, Step 588/2940


Validation Epoch 5, Step 588/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 20:11:14,010 - __main__ - INFO - 验证完成. 平均验证损失: 0.0156
2025-05-18 20:11:14,098 - __main__ - INFO -   验证准确率 (Accuracy): 0.9951
2025-05-18 20:11:14,205 - __main__ - INFO -   验证加权 F1-Score: 0.9951
2025-05-18 20:11:14,208 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9951 at Epoch 5, Step 588/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 20:11:15,824 - __main__ - INFO - 早停指标改善: val_loss 从 0.0181 -> 0.0156. 重置Patience计数器。
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pi

Validation Epoch 5, Step 1176/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 20:20:36,903 - __main__ - INFO - 验证完成. 平均验证损失: 0.0151
2025-05-18 20:20:36,991 - __main__ - INFO -   验证准确率 (Accuracy): 0.9956
2025-05-18 20:20:37,097 - __main__ - INFO -   验证加权 F1-Score: 0.9956
2025-05-18 20:20:37,100 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9956 at Epoch 5, Step 1176/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 20:20:38,488 - __main__ - INFO - 早停指标未改善 (val_loss: 0.0151 vs best: 0.0156). Patience: 1/10
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._p

Validation Epoch 5, Step 1764/2940:   0%|          | 0/313 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4f5a976de0>
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/root/miniconda3/lib/python3.12/site-packages/torc

Validation Epoch 6, Step 1764/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 21:17:08,202 - __main__ - INFO - 验证完成. 平均验证损失: 0.0078
2025-05-18 21:17:08,287 - __main__ - INFO -   验证准确率 (Accuracy): 0.9978
2025-05-18 21:17:08,394 - __main__ - INFO -   验证加权 F1-Score: 0.9977
2025-05-18 21:17:08,397 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9978 at Epoch 6, Step 1764/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 21:17:09,810 - __main__ - INFO - 早停指标改善: val_loss 从 0.0092 -> 0.0078. 重置Patience计数器。
2025-05-18 21:24:37,235 - __main__ - INFO - 开始验证... Epoch 6, Step 2352/2940


Validation Epoch 6, Step 2352/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 21:26:04,636 - __main__ - INFO - 验证完成. 平均验证损失: 0.0074
2025-05-18 21:26:04,723 - __main__ - INFO -   验证准确率 (Accuracy): 0.9979
2025-05-18 21:26:04,832 - __main__ - INFO -   验证加权 F1-Score: 0.9979
2025-05-18 21:26:04,835 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9979 at Epoch 6, Step 2352/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 21:26:36,638 - __main__ - INFO - 早停指标未改善 (val_loss: 0.0074 vs best: 0.0078). Patience: 1/10
2025-05-18 21:34:02,566 - __main__ - INFO - 开始验证... Epoch 6, Step 2940/2940


Validation Epoch 6, Step 2940/2940:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 21:35:29,945 - __main__ - INFO - 验证完成. 平均验证损失: 0.0070
2025-05-18 21:35:30,030 - __main__ - INFO -   验证准确率 (Accuracy): 0.9980
2025-05-18 21:35:30,133 - __main__ - INFO -   验证加权 F1-Score: 0.9980
2025-05-18 21:35:30,136 - __main__ - INFO - 🎉 新的最佳模型 (基于Accuracy)! Accuracy: 0.9980 at Epoch 6, Step 2940/2940. 保存模型到 ./files/saved_models/best_roberta_model.pt
2025-05-18 21:35:31,870 - __main__ - INFO - 早停指标未改善 (val_loss: 0.0070 vs best: 0.0078). Patience: 2/10
2025-05-18 21:35:31,990 - __main__ - INFO - Epoch 6 训练阶段完成. 平均训练损失: 0.0141
2025-05-18 21:35:31,991 - __main__ - INFO - Epoch 6 总耗时: 2827.02 秒
2025-05-18 21:35:31,991 - __main__ - INFO - 所有计划的 Epoch 训练与验证完成。
2025-05-18 21:35:31,992 - __main__ - INFO - 训练期间记录的最佳模型信息: Epoch=6, Step=2940, Val Accuracy=0.9980, Val F1=0.9980, Val Loss=0.0070
2025-05-18 21:35:31,992 - __main__ - INFO - 最佳模型参数已保存至: ./files/saved_models/best_roberta_model.pt (对应验证准确率: 0.9980)
2025-05-18 21:35:31,992 - __main__ - INFO - 总训练耗时: 04小时 41分钟 35.61秒
2025-05-1

### 清缓存

In [9]:
# 3. 清空CUDA缓存
if torch.cuda.is_available():
    log.info(f"当前CUDA显存已分配: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    log.info(f"当前CUDA显存已缓存: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    torch.cuda.empty_cache()
    log.info("已尝试清空CUDA缓存。")
    log.info(f"清空缓存后CUDA显存已分配: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    log.info(f"清空缓存后CUDA显存已缓存: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

# 确保 dev_data_loader, model, tokenizer, id_to_label, num_classes, device, writer 等二次微调仍然需要的对象没有被错误删除

2025-05-18 21:35:32,008 - __main__ - INFO - 当前CUDA显存已分配: 2502.73 MB
2025-05-18 21:35:32,009 - __main__ - INFO - 当前CUDA显存已缓存: 16290.00 MB
2025-05-18 21:35:32,258 - __main__ - INFO - 已尝试清空CUDA缓存。
2025-05-18 21:35:32,259 - __main__ - INFO - 清空缓存后CUDA显存已分配: 2502.73 MB
2025-05-18 21:35:32,259 - __main__ - INFO - 清空缓存后CUDA显存已缓存: 3322.00 MB


## 测试集预测结果分析与数据增强样本生成(Test Set Prediction Analysis & Data Augmentation Sample Generation)

In [10]:
# ==============================================================================
# 主要功能:
# 1. 检查前置条件，加载最佳模型权重
# 2. 对整个测试集执行预测，获取每个样本的 top-1 预测标签及其概率
# 3. 将 top-1 预测结果 (原始ID, 文本, 预测标签, 概率) 保存为 JSON 文件
# 4. 统计预测概率分布:
#    a. 以指定的概率分度值 (PROB_DIST_PLOT_STEP) 从 1 到最低概率构建 x 轴
#    b. 计算在每个概率阈值下，有多少比例的样本其 top-1 预测概率高于该阈值 (y 轴)
#    c. 绘制并保存概率分布柱状图
#    d. 将概率分布的具体数值 (概率阈值, 样本占比) 保存为 CSV 文件
# 5. 数据增强:
#    a. 选取 top-1 预测概率排在前 TOP_PERCENT_FOR_AUGMENTATION (例如80%) 的样本
#    b. 将这些高置信度样本以 "标题\t预测的标签" 格式保存到指定路径下的
#       TRAIN_ADD_FILE_NAME 文件中，用于潜在的数据增强
#    c. 记录用于数据增强的样本中，最低的 top-1 预测概率值
# ==============================================================================
log.info("="*30 + " 开始测试集预测分析与数据增强样本生成 " + "="*30)

# --- 11.1 检查前置条件与配置加载 (Check Prerequisites & Load Configs) ---
required_vars_for_analysis = [
    'model', 'id_to_label', 'log', 'device', 'writer', 
    'test_data_loader', 'raw_test_data', 
    'BEST_MODEL_PATH', 
    'RAW_DATA_DIR',
    'TRAIN_ADD_FILE_PATH', 
    'PROB_DIST_PLOT_STEP', 
    'TOP1_JSON_PATH',
    'TOP_PERCENT_FOR_AUGMENTATION' 
]
log.info("检查分析模块所需的前置变量...")
for var_name in required_vars_for_analysis:
    if var_name not in locals() and var_name not in globals():
        log.critical(f"分析模块执行失败：关键前置变量 '{var_name}' 未定义，请检查之前的单元格")
        exit(1)

if test_data_loader is None or len(test_data_loader) == 0:
    log.critical("测试数据加载器 (test_data_loader) 为空或未定义，无法进行预测分析")
    exit(1)
if not raw_test_data:
    log.critical("原始测试数据 (raw_test_data) 为空，无法获取文本内容")
    exit(1)

log.info("所有分析模块所需的前置变量均已定义")


# --- 11.2 加载最佳模型权重 (Load Best Model Weights) ---
log.info(f"加载最佳模型权重从: {BEST_MODEL_PATH}")
if not os.path.exists(BEST_MODEL_PATH):
    log.critical(f"模型权重文件不存在: {BEST_MODEL_PATH}，无法进行预测")
    exit(1)
try:
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    model.to(device)
    log.info("最佳模型权重加载成功")
except Exception as e:
    log.critical(f"加载模型权重失败: {e}", exc_info=True)
    exit(1)

# --- 11.3 执行全量预测并收集 Top-1 结果 (Perform Full Prediction & Collect Top-1 Results) ---
model.eval() # 设置模型为评估模式
all_top1_results = [] # 存储每个样本的 top-1 预测信息

log.info("开始对所有测试数据进行 Top-1 预测...")
predict_progress_bar = tqdm(test_data_loader, desc="Full Test Prediction (Top-1)", leave=False)

with torch.no_grad():
    for batch_idx, batch in enumerate(predict_progress_bar):
        if 'id' not in batch:
            log.warning(f"批次 {batch_idx} 中缺少 'id' 键，跳过此批次中的部分数据关联")

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1) # (batch_size, num_classes)

        # 获取 Top-1 预测的概率和索引
        top1_probs, top1_indices = torch.max(probabilities, dim=-1) # (batch_size), (batch_size)

        batch_original_ids = batch.get('id', torch.arange(input_ids.size(0)) + batch_idx * test_data_loader.batch_size).cpu().tolist()


        for i in range(top1_probs.size(0)):
            original_id = batch_original_ids[i]
            prob_val = top1_probs[i].item()
            label_idx = top1_indices[i].item()
            predicted_label_text = id_to_label.get(label_idx, "未知标签") # 使用 .get() 避免KeyError

            try:
                # 确保 original_id 是有效的 raw_test_data 索引
                if 0 <= original_id < len(raw_test_data):
                     original_text = raw_test_data[original_id]['text']
                else:
                    log.warning(f"original_id {original_id} 超出 raw_test_data 范围 ({len(raw_test_data)}), 无法获取原始文本")
                    original_text = "原始文本不可用"
            except (IndexError, TypeError) as e_text: # 如果raw_test_data结构不对或id有问题
                log.warning(f"通过 original_id {original_id} 获取原始文本失败: {e_text}")
                original_text = "获取原始文本失败"


            all_top1_results.append({
                "original_id": original_id,
                "text": original_text,
                "predicted_label": predicted_label_text,
                "probability": round(prob_val, 3) # 保留三位小数
            })
predict_progress_bar.close()
log.info(f"Top-1 预测处理完成，共获得 {len(all_top1_results)} 条预测结果")

if not all_top1_results:
    log.critical("未能生成任何 Top-1 预测结果，后续分析无法进行")
    exit(1)

# --- 11.4 保存 Top-1 预测 JSON 结果 (Save Top-1 Prediction JSON Results) ---
try:
    with open(TOP1_JSON_PATH, 'w', encoding='utf-8') as f:
        json.dump(all_top1_results, f, ensure_ascii=False, indent=2)
    log.info(f"全量测试集 Top-1 预测结果已保存至 JSON: {TOP1_JSON_PATH}")
    if all_top1_results:
        log.info("JSON结果示例 (第一条):")
        log.info(json.dumps(all_top1_results[0], ensure_ascii=False, indent=2))
except Exception as e:
    log.error(f"保存 Top-1 预测 JSON 结果时发生错误: {e}", exc_info=True)


# --- 11.5 统计并绘制概率分布图 (Analyze and Plot Probability Distribution) ---
log.info("开始统计和绘制 Top-1 预测概率分布图...")
top1_probabilities = np.array([res['probability'] for res in all_top1_results])

if len(top1_probabilities) == 0:
    log.warning("没有 Top-1 概率数据可供分析，跳过概率分布图绘制")
else:
    min_prob = np.min(top1_probabilities) if len(top1_probabilities) > 0 else 0.0
    # x轴：从1.0到最低概率，以PROB_DIST_PLOT_STEP为步长
    # 确保包含min_prob和1.0，并且步长合理
    prob_thresholds = np.arange(1.0, min_prob - PROB_DIST_PLOT_STEP, -PROB_DIST_PLOT_STEP)
    prob_thresholds = np.clip(prob_thresholds, 0.0, 1.0) # 确保在[0,1]内
    prob_thresholds = np.unique(prob_thresholds)[::-1] # 去重并保持降序，确保1.0在最前
    if prob_thresholds[0] < 1.0: # 确保1.0作为第一个阈值
        prob_thresholds = np.insert(prob_thresholds, 0, 1.0)


    sample_proportions = [] # y轴：高于该概率阈值的样本占比
    for threshold in prob_thresholds:
        count_above_threshold = np.sum(top1_probabilities >= threshold)
        proportion = count_above_threshold / len(top1_probabilities) if len(top1_probabilities) > 0 else 0
        sample_proportions.append(proportion)

    # 绘制柱状图
    try:
        plt.figure(figsize=(12, 7))
        # 为了柱状图美观，x轴标签可能需要调整，如果阈值太多
        bar_positions = np.arange(len(prob_thresholds))
        plt.bar(bar_positions, sample_proportions, width=0.8, color='skyblue')

        plt.xlabel("Top-1 Prediction Probability Threshold (P)")
        plt.ylabel(f"Proportion of Samples with Top-1 Prob >= P (Total Samples: {len(top1_probabilities)})")
        plt.title("Distribution of Top-1 Prediction Probabilities on Test Set")
        
        # 调整x轴刻度标签，避免重叠
        tick_indices = np.linspace(0, len(prob_thresholds) - 1, num=min(15, len(prob_thresholds)), dtype=int) # 最多显示15个刻度
        plt.xticks(bar_positions[tick_indices], [f"{prob_thresholds[i]:.3f}" for i in tick_indices], rotation=45, ha="right")
        
        plt.yticks(np.arange(0, 1.1, 0.1))
        plt.ylim(0, 1.05)
        plt.grid(axis='y', linestyle='--')
        plt.tight_layout()
        plt.savefig(PROB_DIST_PLOT_PATH)
        log.info(f"概率分布柱状图已保存至: {PROB_DIST_PLOT_PATH}")
        plt.close() # 关闭图像，释放内存

        # 将图像记录到TensorBoard
        try:
            image = Image.open(PROB_DIST_PLOT_PATH)
            image_tensor = torch.tensor(np.array(image)).permute(2,0,1) # HWC to CHW
            writer.add_image('Analysis/ProbabilityDistribution', image_tensor, global_step=0) # global_step可以设为0或某个特定值
            log.info("概率分布图已尝试记录到 TensorBoard")
        except Exception as e_tb_img:
            log.warning(f"记录概率分布图到 TensorBoard 失败: {e_tb_img}")


    except Exception as e_plot:
        log.error(f"绘制或保存概率分布图时发生错误: {e_plot}", exc_info=True)

    # 保存概率分布数据到CSV
    try:
        with open(PROB_DIST_CSV_PATH, 'w', newline='', encoding='utf-8') as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(['Probability_Threshold', 'Proportion_Samples_Above_Threshold'])
            for threshold, proportion in zip(prob_thresholds, sample_proportions):
                csv_writer.writerow([f"{threshold:.3f}", f"{proportion:.4f}"])
        log.info(f"概率分布数据已保存至 CSV: {PROB_DIST_CSV_PATH}")
    except Exception as e_csv:
        log.error(f"保存概率分布数据到 CSV 时发生错误: {e_csv}", exc_info=True)


# --- 11.6 生成数据增强文件 (Generate Data Augmentation File) ---
log.info(f"开始根据 Top-1 预测概率生成数据增强文件 (选取前 {TOP_PERCENT_FOR_AUGMENTATION*100:.0f}% 高置信度样本)...")

# 按概率降序排序所有 Top-1 结果
sorted_top1_results = sorted(all_top1_results, key=lambda x: x['probability'], reverse=True)

num_to_select_for_aug = int(len(sorted_top1_results) * TOP_PERCENT_FOR_AUGMENTATION)
if num_to_select_for_aug == 0 and len(sorted_top1_results) > 0: # 确保至少选一个，如果比例太小但有数据
    log.warning(f"根据比例 {TOP_PERCENT_FOR_AUGMENTATION} 计算得到的数据增强样本数为0，将至少选择一个样本（如果存在）")
    # num_to_select_for_aug = 1 # 或者保持为0，不生成文件

augmented_samples = sorted_top1_results[:num_to_select_for_aug]

if augmented_samples:
    min_prob_for_augmentation = augmented_samples[-1]['probability'] # 获取选中样本中的最低概率
    log.info(f"将选取 {len(augmented_samples)} 条样本用于数据增强")
    log.info(f"  这些样本的 Top-1 预测概率范围从 {augmented_samples[0]['probability']:.3f} 到 {min_prob_for_augmentation:.3f}")

    try:
        count_written = 0
        with open(TRAIN_ADD_FILE_PATH, 'w', encoding='utf-8') as f_aug:
            for sample in augmented_samples:
                title = sample['text'].replace('\t', ' ').replace('\n', ' ') # 替换制表符和换行符
                predicted_label = sample['predicted_label']
                f_aug.write(f"{title}\t{predicted_label}\n")
                count_written +=1
        log.info(f"成功将 {count_written} 条预测写入数据增强文件: {TRAIN_ADD_FILE_PATH}")
        if count_written > 0:
            log.info("数据增强文件内容示例 (前3行):")
            with open(TRAIN_ADD_FILE_PATH, 'r', encoding='utf-8') as f_preview:
                for i, line in enumerate(f_preview):
                    if i < 3:
                        log.info(f"  {line.strip()}")
                    else:
                        break
    except Exception as e_aug:
        log.error(f"写入数据增强文件 {TRAIN_ADD_FILE_PATH} 时发生错误: {e_aug}", exc_info=True)
else:
    log.info("没有足够的样本满足数据增强条件，或未生成任何预测结果，未创建数据增强文件")
log.info("="*30 + " 测试集预测分析与数据增强样本生成结束 " + "="*30 + "\n")

2025-05-18 21:36:02,318 - __main__ - INFO - 检查分析模块所需的前置变量...
2025-05-18 21:36:02,319 - __main__ - INFO - 所有分析模块所需的前置变量均已定义
2025-05-18 21:36:02,320 - __main__ - INFO - 加载最佳模型权重从: ./files/saved_models/best_roberta_model.pt
  model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
2025-05-18 21:36:03,128 - __main__ - INFO - 最佳模型权重加载成功
2025-05-18 21:36:03,131 - __main__ - INFO - 开始对所有测试数据进行 Top-1 预测...


Full Test Prediction (Top-1):   0%|          | 0/327 [00:00<?, ?it/s]

2025-05-18 21:37:34,207 - __main__ - INFO - Top-1 预测处理完成，共获得 83599 条预测结果
2025-05-18 21:37:34,635 - __main__ - INFO - 全量测试集 Top-1 预测结果已保存至 JSON: ./files/analysis_results/all_test_analysis/test_top1_predictions.json
2025-05-18 21:37:34,635 - __main__ - INFO - JSON结果示例 (第一条):
2025-05-18 21:37:34,636 - __main__ - INFO - {
  "original_id": 0,
  "text": "北京君太百货璀璨秋色 满100省353020元",
  "predicted_label": "房产",
  "probability": 0.999
}
2025-05-18 21:37:34,636 - __main__ - INFO - 开始统计和绘制 Top-1 预测概率分布图...
2025-05-18 21:38:05,847 - __main__ - INFO - 概率分布柱状图已保存至: ./files/analysis_results/all_test_analysis/test_prob_distribution.png
2025-05-18 21:38:05,969 - __main__ - INFO - 概率分布图已尝试记录到 TensorBoard
2025-05-18 21:38:05,981 - __main__ - INFO - 概率分布数据已保存至 CSV: ./files/analysis_results/all_test_analysis/test_prob_distribution_data.csv
2025-05-18 21:38:05,982 - __main__ - INFO - 开始根据 Top-1 预测概率生成数据增强文件 (选取前 80% 高置信度样本)...
2025-05-18 21:38:05,996 - __main__ - INFO - 将选取 66879 条样本用于数据增强
2025-05-18 21:38:05,

## 增强训练

In [13]:
# ==============================================================================
# 单元格 12: 使用增强数据进行二次微调 (Second Stage Fine-tuning with Augmented Data)
# 承接之前生成的数据增强文件 (TRAIN_ADD_FILE_PATH)，
# 以及第一阶段微调得到的最佳模型 (BEST_MODEL_PATH)。
#
# 主要功能:
# 1. 检查前置条件和配置文件。
# 2. 加载由高置信度预测生成的增强训练数据。
# 3. 对增强数据进行预处理 (Tokenization, Dataset creation)，可选择使用新的缓存路径。
# 4. 创建用于增强数据训练的 DataLoader。
# 5. 加载第一阶段微调后的最佳模型权重。
# 6. (可选) 为二次微调配置新的（通常更小的）学习率、优化器和学习率调度器。
# 7. 执行训练和验证循环：
#    - 使用增强数据进行训练。
#    - 在原始验证集 (dev_data_loader) 上进行评估。
#    - 无早停机制。
#    - 记录训练过程到 TensorBoard。
# 8. (可选) 保存二次微调后的最佳模型。
# ==============================================================================
log.info("="*30 + " 开始使用增强数据进行二次微调 " + "="*30)

# --- 12.1 检查前置条件与配置加载 (Check Prerequisites & Load Configs) ---
required_vars_for_aug_finetuning = [
    'model', 'tokenizer', 'id_to_label', 'num_classes', 'device', 'log', 'writer',
    'dev_data_loader', # 使用原始验证集进行评估
    'AUG_TRAIN_FILE_PATH', # 增强数据文件路径 (原TRAIN_ADD_FILE_PATH)
    'CACHE_AUG_TRAIN_PATH', # 增强数据缓存路径 (新)
    'BEST_MODEL_PATH',      # 初次微调的最佳模型
    'BEST_AUG_MODEL_PATH',  # 二次微调模型保存路径 (新)
    'AUG_LEARNING_RATE', 'AUG_NUM_EPOCHS', 'AUG_WARMUP_PROPORTION', 'AUG_WEIGHT_DECAY',
    'BATCH_SIZE', 'NUM_WORKERS', 'MAX_SEQ_LENGTH', 'NUM_PROC_FOR_MAP',
    'EVAL_STRATEGY', # 复用或定义 AUG_EVAL_STRATEGY
]
if EVAL_STRATEGY == "steps": # 如果按步验证，也需要频率参数
    required_vars_for_aug_finetuning.append('EVAL_FREQUENCY_FRAC_EPOCH') # 复用或定义 AUG_EVAL_FREQUENCY_FRAC_EPOCH

log.info("检查二次微调所需的前置变量...")
for var_name in required_vars_for_aug_finetuning:
    if var_name not in locals() and var_name not in globals():
        log.critical(f"二次微调失败：关键前置变量 '{var_name}' 未定义。请检查之前的单元格和配置。")
        exit(1)
log.info("所有二次微调所需的前置变量均已定义。")

# --- 12.2 加载数据增强文件 (Load Augmented Data File) ---
log.info(f"开始加载数据增强文件: {AUG_TRAIN_FILE_PATH}")
if not os.path.exists(AUG_TRAIN_FILE_PATH):
    log.critical(f"数据增强文件 {AUG_TRAIN_FILE_PATH} 不存在。无法进行二次微调。")
    exit(1)

# 使用之前定义的 load_data_from_file 函数加载，格式为 "text\tlabel"
try:
    raw_aug_train_data = load_data_from_file(AUG_TRAIN_FILE_PATH, is_test_set=False)
    if not raw_aug_train_data:
        log.critical(f"从 {AUG_TRAIN_FILE_PATH} 加载的数据为空。无法进行二次微调。")
        exit(1)
    log.info(f"成功从 {AUG_TRAIN_FILE_PATH} 加载 {len(raw_aug_train_data)} 条增强训练数据。")
except Exception as e:
    log.critical(f"加载数据增强文件 {AUG_TRAIN_FILE_PATH} 失败: {e}", exc_info=True)
    exit(1)

# --- 12.3 预处理增强数据 (Preprocess Augmented Data) ---
# 使用与之前类似的预处理流程，但使用新的缓存路径
log.info("开始预处理增强训练数据...")
aug_train_dataset = None
try:
    if os.path.exists(CACHE_AUG_TRAIN_PATH):
        log.info(f"发现预处理好的增强训练数据缓存，正在从磁盘加载: {CACHE_AUG_TRAIN_PATH}")
        aug_train_dataset = Dataset.load_from_disk(CACHE_AUG_TRAIN_PATH)
        log.info("从磁盘加载预处理增强训练数据成功。")
    else:
        log.info(f"未找到增强训练数据缓存，正在重新处理: {CACHE_AUG_TRAIN_PATH}")
        aug_train_dataset = create_and_process_dataset( # 使用单元格7定义的函数
            raw_data_list=raw_aug_train_data,
            tokenizer_instance=tokenizer,
            max_len=MAX_SEQ_LENGTH,
            dataset_display_name="增强训练集",
            is_test_set=False,
            label_map=label_to_id, # 使用原始的label_to_id映射
            num_processing_workers=NUM_PROC_FOR_MAP
        )
        log.info(f"保存预处理后的增强训练集到: {CACHE_AUG_TRAIN_PATH}")
        aug_train_dataset.save_to_disk(CACHE_AUG_TRAIN_PATH)
        log.info("增强训练数据预处理完成并已保存到磁盘缓存。")

    # 设置Dataset格式为PyTorch Tensors
    aug_train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    log.info("增强训练数据集格式已设置为 PyTorch Tensors。")

except Exception as e:
    log.critical(f"预处理增强训练数据时发生严重错误: {e}", exc_info=True)
    exit(1)

# --- 12.4 创建增强数据的 DataLoader (Create DataLoader for Augmented Data) ---
log.info("创建增强训练数据的 DataLoader...")
pin_memory_enabled_aug = torch.cuda.is_available()
try:
    aug_train_data_loader = DataLoader(
        dataset=aug_train_dataset,
        batch_size=BATCH_SIZE, # 可以使用与初次训练相同的BATCH_SIZE，或单独配置
        shuffle=True,
        collate_fn=collator, # 复用之前创建的collator
        #num_workers=NUM_WORKERS, 
        num_workers = 0,
        pin_memory=pin_memory_enabled_aug
    )
    log.info(f"增强训练数据 DataLoader 创建成功 (Batch Size: {BATCH_SIZE}, Num Workers: {0})。")
    if len(aug_train_data_loader) == 0:
        log.critical("增强训练数据 DataLoader 为空，无法进行二次微调。")
        exit(1)
except Exception as e:
    log.critical(f"创建增强训练数据 DataLoader 失败: {e}", exc_info=True)
    exit(1)

# --- 12.5 加载已微调的模型 (Load Fine-tuned Model) ---
log.info(f"加载第一阶段微调后的最佳模型权重从: {BEST_MODEL_PATH}")
# model 对象应该仍然是之前训练/加载的那个实例，我们直接加载状态字典
try:
    #model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    model.load_state_dict(torch.load(BEST_AUG_MODEL_PATH, map_location=device))
    model.to(device) # 确保模型在正确的设备上
    log.info(f"成功加载来自 '{BEST_MODEL_PATH}' 的模型权重。")
except Exception as e:
    log.critical(f"加载模型权重 {BEST_MODEL_PATH} 失败: {e}", exc_info=True)
    exit(1)

# --- 12.6 配置二次微调的优化器和学习率调度器 (Configure Optimizer & Scheduler for Augmentation Fine-tuning) ---
log.info("为二次微调配置新的优化器和学习率调度器...")
log.info(f"二次微调超参数: Learning Rate={AUG_LEARNING_RATE}, Epochs={AUG_NUM_EPOCHS}, "
         f"Warmup Proportion={AUG_WARMUP_PROPORTION}, Weight Decay={AUG_WEIGHT_DECAY}")
try:
    aug_optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=AUG_LEARNING_RATE,
        weight_decay=AUG_WEIGHT_DECAY,
        eps=1e-8
    )
    log.info("二次微调 AdamW 优化器创建成功。")

    aug_num_training_steps = AUG_NUM_EPOCHS * len(aug_train_data_loader)
    aug_num_warmup_steps = int(AUG_WARMUP_PROPORTION * aug_num_training_steps)
    if aug_num_training_steps == 0:
        log.critical("二次微调的总训练步数为0，无法创建学习率调度器。")
        exit(1)

    log.info(f"二次微调总训练步数: {aug_num_training_steps}, 预热步数: {aug_num_warmup_steps}")
    aug_lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=aug_optimizer,
        num_warmup_steps=aug_num_warmup_steps,
        num_training_steps=aug_num_training_steps
    )
    log.info("二次微调学习率调度器创建成功。")
except Exception as e:
    log.critical(f"创建二次微调优化器或学习率调度器失败: {e}", exc_info=True)
    exit(1)

# --- 12.7 执行二次微调训练与验证循环 (Execute Second Stage Fine-tuning Loop) ---
log.info("开始二次微调训练与验证循环...")

best_aug_val_metric_for_saving = -float('inf') # 用于保存二次微调模型的最佳指标值
current_best_aug_model_info = {"epoch": 0, "step":0, "val_accuracy": 0.0, "val_f1": 0.0, "val_loss": float('inf')}
global_aug_step = 0 # 用于二次微调的 TensorBoard 全局步数

# 根据 EVAL_STRATEGY 计算验证频率 (如果使用基于steps的验证)
aug_eval_every_n_steps = 0
# current_eval_strategy = getattr('AUG_EVAL_STRATEGY', EVAL_STRATEGY) # 允许为增强训练单独配置策略
# current_eval_freq_frac = getattr('AUG_EVAL_FREQUENCY_FRAC_EPOCH', EVAL_FREQUENCY_FRAC_EPOCH)
current_eval_strategy = globals().get('AUG_EVAL_STRATEGY', EVAL_STRATEGY)
current_eval_freq_frac = globals().get('AUG_EVAL_FREQUENCY_FRAC_EPOCH', EVAL_FREQUENCY_FRAC_EPOCH)
if current_eval_strategy == "steps":
    if not (0 < current_eval_freq_frac <= 1.0):
        log.warning(f"二次微调的 EVAL_FREQUENCY_FRAC_EPOCH ({current_eval_freq_frac}) 无效，将按epoch结束时验证。")
        current_eval_strategy = "epoch"
    else:
        aug_eval_every_n_steps = max(1, int(len(aug_train_data_loader) * current_eval_freq_frac))
        log.info(f"二次微调按步数验证：每 {aug_eval_every_n_steps} 步 (约占epoch的 {current_eval_freq_frac*100:.1f}%) 进行一次验证。")
else:
    log.info("二次微调按Epoch结束验证策略。")


global_aug_start_time = time.time()
training_aug_completed_normally = False

try:
    for epoch in range(1, AUG_NUM_EPOCHS + 1):
        epoch_aug_start_time = time.time()
        log.info(f"--- 二次微调 Epoch {epoch}/{AUG_NUM_EPOCHS} ---")
        model.train()
        total_train_loss_aug_epoch = 0
        aug_train_progress_bar = tqdm(aug_train_data_loader, desc=f"Aug Epoch {epoch} Training", leave=True, position=0)

        for step, batch in enumerate(aug_train_progress_bar):
            global_aug_step += 1

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            aug_optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # (可选) 梯度裁剪
            aug_optimizer.step()
            aug_lr_scheduler.step()

            total_train_loss_aug_epoch += loss.item()
            current_aug_lr = aug_optimizer.param_groups[0]["lr"]

            aug_train_progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss_epoch': f'{total_train_loss_aug_epoch / (step + 1):.4f}',
                'lr': f'{current_aug_lr:.2e}'
            })

            if global_aug_step % (max(1, len(aug_train_data_loader) // 100)) == 0:
                writer.add_scalar('Loss/AugTrain_step', loss.item(), global_aug_step)
                writer.add_scalar('LearningRate/Aug_step', current_aug_lr, global_aug_step)

            # --- 定期验证 (二次微调期间) ---
            if current_eval_strategy == "steps" and (step + 1) % aug_eval_every_n_steps == 0:
                step_info = f"Aug Epoch {epoch}, Step {step+1}/{len(aug_train_data_loader)}"
                # 使用原始的 dev_data_loader 进行验证
                val_loss, val_accuracy, val_f1_weighted = evaluate_model(model, dev_data_loader, device, step_info)
                model.train() # 确保验证后模型回到训练模式

                writer.add_scalar('Loss/AugValidation_step', val_loss, global_aug_step)
                if val_accuracy != -1.0: writer.add_scalar('Accuracy/AugValidation_step', val_accuracy, global_aug_step)
                if val_f1_weighted != -1.0: writer.add_scalar('F1_weighted/AugValidation_step', val_f1_weighted, global_aug_step)

                if val_accuracy > best_aug_val_metric_for_saving:
                    best_aug_val_metric_for_saving = val_accuracy
                    current_best_aug_model_info = {"epoch": epoch, "step":step+1, "val_accuracy": val_accuracy, "val_f1": val_f1_weighted, "val_loss": val_loss}
                    log.info(f"🎉 新的最佳增强模型 (基于Accuracy)! Accuracy: {val_accuracy:.4f} at {step_info}. 保存模型到 {BEST_AUG_MODEL_PATH}")
                    torch.save(model.state_dict(), BEST_AUG_MODEL_PATH)
                else:
                    log.info(f"当前验证 Accuracy {val_accuracy:.4f} 未超过二次微调历史最佳 {best_aug_val_metric_for_saving:.4f} (at {step_info}).")

        aug_train_progress_bar.close()
        avg_epoch_aug_train_loss = total_train_loss_aug_epoch / len(aug_train_data_loader) if len(aug_train_data_loader) > 0 else float('inf')
        log.info(f"二次微调 Epoch {epoch} 训练阶段完成. 平均训练损失: {avg_epoch_aug_train_loss:.4f}")
        writer.add_scalar('Loss/AugTrain_epoch', avg_epoch_aug_train_loss, epoch)


        # --- Epoch结束时验证 (二次微调期间) ---
        perform_aug_epoch_end_eval = (current_eval_strategy == "epoch") or \
                                     (current_eval_strategy == "steps" and len(aug_train_data_loader) % aug_eval_every_n_steps != 0)

        if perform_aug_epoch_end_eval:
            epoch_end_info = f"Aug Epoch {epoch} End"
            val_loss, val_accuracy, val_f1_weighted = evaluate_model(model, dev_data_loader, device, epoch_end_info)

            writer.add_scalar('Loss/AugValidation_epoch', val_loss, epoch)
            if val_accuracy != -1.0: writer.add_scalar('Accuracy/AugValidation_epoch', val_accuracy, epoch)
            if val_f1_weighted != -1.0: writer.add_scalar('F1_weighted/AugValidation_epoch', val_f1_weighted, epoch)

            if val_accuracy > best_aug_val_metric_for_saving:
                best_aug_val_metric_for_saving = val_accuracy
                current_best_aug_model_info = {"epoch": epoch, "step":"end_of_epoch", "val_accuracy": val_accuracy, "val_f1": val_f1_weighted, "val_loss": val_loss}
                log.info(f"🎉 新的最佳增强模型 (基于Accuracy)! Accuracy: {val_accuracy:.4f} at {epoch_end_info}. 保存模型到 {BEST_AUG_MODEL_PATH}")
                torch.save(model.state_dict(), BEST_AUG_MODEL_PATH)
            else:
                log.info(f"当前验证 Accuracy {val_accuracy:.4f} 未超过二次微调历史最佳 {best_aug_val_metric_for_saving:.4f} (at {epoch_end_info}).")

        epoch_aug_duration = time.time() - epoch_aug_start_time
        log.info(f"二次微调 Epoch {epoch} 总耗时: {epoch_aug_duration:.2f} 秒")
        writer.add_scalar('Time/Aug_epoch_duration_seconds', epoch_aug_duration, epoch)

    training_aug_completed_normally = True

except KeyboardInterrupt:
    log.warning("二次微调被用户手动中断 (KeyboardInterrupt)。")
    training_aug_completed_normally = False
except Exception as e_train_aug:
    log.critical(f"二次微调过程中发生严重错误: {e_train_aug}", exc_info=True)
    training_aug_completed_normally = False
finally:
    if training_aug_completed_normally:
        log.info("所有计划的二次微调 Epoch 训练与验证完成。")
    else:
        log.info("二次微调训练未完成所有计划的 Epoch。")

    log.info(f"二次微调期间记录的最佳模型信息: Epoch={current_best_aug_model_info['epoch']}, Step={current_best_aug_model_info['step']}, "
             f"Val Accuracy={current_best_aug_model_info['val_accuracy']:.4f}, "
             f"Val F1={current_best_aug_model_info['val_f1']:.4f}, "
             f"Val Loss={current_best_aug_model_info['val_loss']:.4f}")
    if os.path.exists(BEST_AUG_MODEL_PATH) and best_aug_val_metric_for_saving > -float('inf'):
        log.info(f"最佳二次微调模型参数已保存至: {BEST_AUG_MODEL_PATH} (对应验证准确率: {best_aug_val_metric_for_saving:.4f})")
    else:
        log.warning(f"未成功保存任何最佳二次微调模型到 {BEST_AUG_MODEL_PATH}。")

    total_aug_training_time = time.time() - global_aug_start_time
    hours, rem = divmod(total_aug_training_time, 3600)
    minutes, seconds = divmod(rem, 60)
    log.info(f"总二次微调耗时: {int(hours):02d}小时 {int(minutes):02d}分钟 {seconds:.2f}秒")

log.info("="*30 + " 使用增强数据进行二次微调结束 " + "="*30 + "\n")


2025-05-18 22:00:51,052 - __main__ - INFO - 检查二次微调所需的前置变量...
2025-05-18 22:00:51,052 - __main__ - INFO - 所有二次微调所需的前置变量均已定义。
2025-05-18 22:00:51,053 - __main__ - INFO - 开始加载数据增强文件: ./files/raw_data/train_add.txt
2025-05-18 22:00:51,054 - __main__ - INFO - 开始从文件加载数据: ./files/raw_data/train_add.txt
2025-05-18 22:00:51,112 - __main__ - INFO - 从 ./files/raw_data/train_add.txt 成功加载 66879 条数据
2025-05-18 22:00:51,142 - __main__ - INFO - 成功从 ./files/raw_data/train_add.txt 加载 66879 条增强训练数据。
2025-05-18 22:00:51,143 - __main__ - INFO - 开始预处理增强训练数据...
2025-05-18 22:00:51,143 - __main__ - INFO - 发现预处理好的增强训练数据缓存，正在从磁盘加载: ./files/processed_data_cache/aug_train_dataset_cache
2025-05-18 22:00:51,149 - __main__ - INFO - 从磁盘加载预处理增强训练数据成功。
2025-05-18 22:00:51,150 - __main__ - INFO - 增强训练数据集格式已设置为 PyTorch Tensors。
2025-05-18 22:00:51,150 - __main__ - INFO - 创建增强训练数据的 DataLoader...
2025-05-18 22:00:51,151 - __main__ - INFO - 增强训练数据 DataLoader 创建成功 (Batch Size: 256, Num Workers: 0)。
2025-05-18 22:00:51,151 - 

Aug Epoch 1 Training:   0%|          | 0/262 [00:00<?, ?it/s]

2025-05-18 22:04:41,166 - __main__ - INFO - 开始验证... Aug Epoch 1, Step 262/262


Validation Aug Epoch 1, Step 262/262:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 22:06:07,358 - __main__ - INFO - 验证完成. 平均验证损失: 0.0322
2025-05-18 22:06:07,448 - __main__ - INFO -   验证准确率 (Accuracy): 0.9897
2025-05-18 22:06:07,554 - __main__ - INFO -   验证加权 F1-Score: 0.9897
2025-05-18 22:06:07,557 - __main__ - INFO - 🎉 新的最佳增强模型 (基于Accuracy)! Accuracy: 0.9897 at Aug Epoch 1, Step 262/262. 保存模型到 ./files/saved_models/best_augmented_model.pt
2025-05-18 22:06:09,297 - __main__ - INFO - 二次微调 Epoch 1 训练阶段完成. 平均训练损失: 0.0066
2025-05-18 22:06:09,298 - __main__ - INFO - 二次微调 Epoch 1 总耗时: 317.11 秒
2025-05-18 22:06:09,334 - __main__ - INFO - --- 二次微调 Epoch 2/3 ---


Aug Epoch 2 Training:   0%|          | 0/262 [00:00<?, ?it/s]

2025-05-18 22:09:59,621 - __main__ - INFO - 开始验证... Aug Epoch 2, Step 262/262


Validation Aug Epoch 2, Step 262/262:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 22:11:26,277 - __main__ - INFO - 验证完成. 平均验证损失: 0.0344
2025-05-18 22:11:26,368 - __main__ - INFO -   验证准确率 (Accuracy): 0.9896
2025-05-18 22:11:26,478 - __main__ - INFO -   验证加权 F1-Score: 0.9896
2025-05-18 22:11:26,481 - __main__ - INFO - 当前验证 Accuracy 0.9896 未超过二次微调历史最佳 0.9897 (at Aug Epoch 2, Step 262/262).
2025-05-18 22:11:26,483 - __main__ - INFO - 二次微调 Epoch 2 训练阶段完成. 平均训练损失: 0.0030
2025-05-18 22:11:26,484 - __main__ - INFO - 二次微调 Epoch 2 总耗时: 317.15 秒
2025-05-18 22:11:26,484 - __main__ - INFO - --- 二次微调 Epoch 3/3 ---


Aug Epoch 3 Training:   0%|          | 0/262 [00:00<?, ?it/s]

2025-05-18 22:15:16,680 - __main__ - INFO - 开始验证... Aug Epoch 3, Step 262/262


Validation Aug Epoch 3, Step 262/262:   0%|          | 0/313 [00:00<?, ?it/s]

2025-05-18 22:16:43,389 - __main__ - INFO - 验证完成. 平均验证损失: 0.0322
2025-05-18 22:16:43,482 - __main__ - INFO -   验证准确率 (Accuracy): 0.9903
2025-05-18 22:16:43,591 - __main__ - INFO -   验证加权 F1-Score: 0.9904
2025-05-18 22:16:43,594 - __main__ - INFO - 🎉 新的最佳增强模型 (基于Accuracy)! Accuracy: 0.9903 at Aug Epoch 3, Step 262/262. 保存模型到 ./files/saved_models/best_augmented_model.pt
2025-05-18 22:16:44,901 - __main__ - INFO - 二次微调 Epoch 3 训练阶段完成. 平均训练损失: 0.0010
2025-05-18 22:16:44,902 - __main__ - INFO - 二次微调 Epoch 3 总耗时: 318.42 秒
2025-05-18 22:16:44,902 - __main__ - INFO - 所有计划的二次微调 Epoch 训练与验证完成。
2025-05-18 22:16:44,903 - __main__ - INFO - 二次微调期间记录的最佳模型信息: Epoch=3, Step=262, Val Accuracy=0.9903, Val F1=0.9904, Val Loss=0.0322
2025-05-18 22:16:44,903 - __main__ - INFO - 最佳二次微调模型参数已保存至: ./files/saved_models/best_augmented_model.pt (对应验证准确率: 0.9903)
2025-05-18 22:16:44,904 - __main__ - INFO - 总二次微调耗时: 00小时 15分钟 52.72秒



## 最终测试集预测与结果分析(Final Test Set Prediction & Result Analysis)

In [None]:
# ==============================================================================
# 主要功能:
# 1. 检查前置条件，加载二次微调后的最佳模型权重。
# 2. 对整个测试集执行预测，获取每个样本的 top-1 预测标签及其概率。
# 3. 结果保存:
#    a. Top-1 预测结果 (原始ID, 文本, 预测标签, 概率) 保存为 JSON 文件。
#    b. 仅包含预测标签的文本文件 (result.txt)。
# 4. 统计预测概率分布 (与单元格11类似，但使用最终模型的预测结果):
#    a. 绘制并保存概率分布柱状图。
#    b. 将概率分布的具体数值保存为 CSV 文件。
# ==============================================================================
log.info("="*30 + " 开始最终测试集预测与结果分析 (使用二次微调模型) " + "="*30)

# --- 13.1 检查前置条件与配置加载 (Check Prerequisites & Load Configs) ---
required_vars_for_final_prediction = [
     'id_to_label', 'log', 'device', 'writer',
    'test_data_loader', 'raw_test_data',
    'BEST_AUG_MODEL_PATH', # 二次微调后的最佳模型路径
    'FINAL_RESULT_TXT_PATH',
    'FINAL_PROB_DIST_CSV_PATH', # 最终分析图表和CSV保存目录
    'FINAL_PROB_DIST_PLOT_PATH',
    'PROB_DIST_PLOT_STEP'  # 概率分布图的分度值
]
log.info("检查最终预测与分析所需的前置变量...")
for var_name in required_vars_for_final_prediction:
    if var_name not in locals() and var_name not in globals():
        log.critical(f"最终预测分析失败：关键前置变量 '{var_name}' 未定义。请检查之前的单元格。")
        exit(1)

if test_data_loader is None or len(test_data_loader) == 0:
    log.critical("测试数据加载器 (test_data_loader) 为空或未定义，无法进行最终预测。")
    exit(1)
if not raw_test_data:
    log.critical("原始测试数据 (raw_test_data) 为空，无法获取文本内容。")
    exit(1)

log.info("所有最终预测分析所需的前置变量均已定义。")



# --- 13.2 加载二次微调后的最佳模型权重 (Load Best Second Stage Fine-tuned Model Weights) ---
log.info(f"加载二次微调后的最佳模型权重从: {BEST_AUG_MODEL_PATH}")
if not os.path.exists(BEST_AUG_MODEL_PATH):
    log.warning(f"二次微调后的最佳模型文件 {BEST_AUG_MODEL_PATH} 不存在。")
    log.warning(f"将尝试使用当前内存中的 'model' 对象进行预测（可能是二次微调的最后状态）。")
    log.warning(f"如果希望使用磁盘上特定的最佳增强模型，请确保 {BEST_AUG_MODEL_PATH} 存在。")

    if model is None: #
        log.critical("当前 'model'为 None，且无法从磁盘加载二次微调模型。无法进行预测。")
        exit(1)
    model.to(device) 
    log.info("已准备使用当前内存中的模型进行预测。")
else: # BEST_AUG_MODEL_PATH 存在
    try:
        model.load_state_dict(torch.load(BEST_AUG_MODEL_PATH, map_location=device))
        model.to(device)
        log.info(f"二次微调后的最佳模型权重从 {BEST_AUG_MODEL_PATH} 加载成功。")
    except Exception as e:
        log.critical(f"加载二次微调模型权重 {BEST_AUG_MODEL_PATH} 失败: {e}", exc_info=True)
        exit(1)

# --- 13.3 执行全量预测并收集 Top-1 结果 (Perform Full Prediction & Collect Top-1 Results) ---
model.eval() 
final_all_top1_results = [] 

log.info("开始对所有测试数据进行最终 Top-1 预测 (使用二次微调模型)...")
final_predict_progress_bar = tqdm(test_data_loader, desc="Final Test Prediction (Top-1)", leave=False)

with torch.no_grad():
    for batch_idx, batch in enumerate(final_predict_progress_bar):
        if 'id' not in batch:
            log.warning(f"最终预测批次 {batch_idx} 中缺少 'id' 键，跳过此批次中的部分数据关联。")
            # 对于最终预测，原始ID对于结果的完整性很重要

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)

        top1_probs, top1_indices = torch.max(probabilities, dim=-1)
        
        # 获取原始ID，如果 'id' 不在batch中，则生成一个基于batch索引的伪ID（不推荐用于真实场景）
        batch_original_ids = batch.get('id')
        if batch_original_ids is None:
            log.error(f"最终预测批次 {batch_idx} 中 'id' 键缺失")
            exit(1)
        else:
            batch_original_ids_list = batch_original_ids.cpu().tolist()

        for i in range(top1_probs.size(0)):
            original_id = batch_original_ids_list[i]
            prob_val = top1_probs[i].item()
            label_idx = top1_indices[i].item()
            predicted_label_text = id_to_label.get(label_idx, f"未知标签ID:{label_idx}")

            original_text = "原始文本不可用" # 默认值
            try:
                # 确保 original_id 是有效的 raw_test_data 索引
                # 这个 original_id 应该与 raw_test_data 创建时的id对应
                if 0 <= original_id < len(raw_test_data) and isinstance(raw_test_data[original_id], dict) and 'text' in raw_test_data[original_id]:
                     original_text = raw_test_data[original_id]['text']
                else:
                    # 尝试通过查找匹配的 'id' 字段（如果 raw_test_data 中的 id 不是简单索引）
                    found_text = False
                    for item_idx, item in enumerate(raw_test_data): # 效率较低，但作为后备
                        if isinstance(item, dict) and item.get('id') == original_id:
                            original_text = item['text']
                            found_text = True
                            break
                    if not found_text:
                        log.warning(f"无法通过 original_id {original_id} (来自DataLoader) 在 raw_test_data 中找到对应文本。")
            except Exception as e_text:
                log.warning(f"通过 original_id {original_id} 获取原始文本时发生错误: {e_text}")
            
            final_all_top1_results.append({
                "original_id": original_id,
                "text": original_text,
                "predicted_label": predicted_label_text,
                "probability": round(prob_val, 3) # 保留三位小数
            })
final_predict_progress_bar.close()
log.info(f"最终 Top-1 预测处理完成，共获得 {len(final_all_top1_results)} 条预测结果。")

if not final_all_top1_results:
    log.critical("未能生成任何最终 Top-1 预测结果，后续分析无法进行。")
    exit(1)

# --- 13.4 保存最终预测结果 (Save Final Prediction Results) ---
# 13.4.1 保存 Top-1 预测 JSON 结果
try:
    with open(FINAL_TOP1_JSON_PATH, 'w', encoding='utf-8') as f:
        json.dump(final_all_top1_results, f, ensure_ascii=False, indent=2)
    log.info(f"最终测试集 Top-1 预测结果已保存至 JSON: {FINAL_TOP1_JSON_PATH}")
    if final_all_top1_results: # 打印示例
        log.info("JSON结果示例 (第一条):")
        log.info(json.dumps(final_all_top1_results[0], ensure_ascii=False, indent=2))
except Exception as e:
    log.error(f"保存最终 Top-1 预测 JSON 结果时发生错误: {e}", exc_info=True)

# 13.4.2 保存仅含标签的 result.txt
try:
    count_written_txt = 0
    with open(FINAL_RESULT_TXT_PATH, 'w', encoding='utf-8') as f_txt:
        for result_item in final_all_top1_results:
            f_txt.write(f"{result_item['predicted_label']}\n")
            count_written_txt +=1
    log.info(f"成功将 {count_written_txt} 条预测标签写入纯文本文件: {FINAL_RESULT_TXT_PATH}")
    if count_written_txt > 0:
        log.info("result.txt 内容示例 (前3行，如果存在):")
        with open(FINAL_RESULT_TXT_PATH, 'r', encoding='utf-8') as f_preview:
            for i, line in enumerate(f_preview):
                if i < 3:
                    log.info(f"  {line.strip()}")
                else:
                    break
except Exception as e:
    log.error(f"保存 result.txt 文件时发生错误: {e}", exc_info=True)


# --- 13.5 统计并绘制最终概率分布图 (Analyze and Plot Final Probability Distribution) ---
log.info("开始统计和绘制最终 Top-1 预测概率分布图...")
final_top1_probabilities = np.array([res['probability'] for res in final_all_top1_results])

if len(final_top1_probabilities) == 0:
    log.warning("没有最终 Top-1 概率数据可供分析，跳过概率分布图绘制。")
else:
    min_prob_final = np.min(final_top1_probabilities) if len(final_top1_probabilities) > 0 else 0.0
    final_prob_thresholds = np.arange(1.0, min_prob_final - PROB_DIST_PLOT_STEP, -PROB_DIST_PLOT_STEP)
    final_prob_thresholds = np.clip(final_prob_thresholds, 0.0, 1.0)
    final_prob_thresholds = np.unique(final_prob_thresholds)[::-1]
    if final_prob_thresholds[0] < 1.0:
        final_prob_thresholds = np.insert(final_prob_thresholds, 0, 1.0)

    final_sample_proportions = []
    for threshold in final_prob_thresholds:
        count_above_threshold = np.sum(final_top1_probabilities >= threshold)
        proportion = count_above_threshold / len(final_top1_probabilities) if len(final_top1_probabilities) > 0 else 0
        final_sample_proportions.append(proportion)

    # 绘制柱状图
    try:
        plt.figure(figsize=(12, 7))
        bar_positions_final = np.arange(len(final_prob_thresholds))
        plt.bar(bar_positions_final, final_sample_proportions, width=0.8, color='mediumseagreen') # 不同颜色

        plt.xlabel("Top-1 Prediction Probability Threshold (P) - Final Model")
        plt.ylabel(f"Proportion of Samples with Top-1 Prob >= P (Total Samples: {len(final_top1_probabilities)})")
        plt.title("Distribution of Top-1 Prediction Probabilities on Test Set (Final Fine-tuned Model)")
        
        tick_indices_final = np.linspace(0, len(final_prob_thresholds) - 1, num=min(15, len(final_prob_thresholds)), dtype=int)
        plt.xticks(bar_positions_final[tick_indices_final], [f"{final_prob_thresholds[i]:.3f}" for i in tick_indices_final], rotation=45, ha="right")
        
        plt.yticks(np.arange(0, 1.1, 0.1))
        plt.ylim(0, 1.05)
        plt.grid(axis='y', linestyle='--')
        plt.tight_layout()
        plt.savefig(FINAL_PROB_DIST_PLOT_PATH)
        log.info(f"最终概率分布柱状图已保存至: {FINAL_PROB_DIST_PLOT_PATH}")
        plt.close()

        # 将图像记录到TensorBoard
        try:
            image = Image.open(FINAL_PROB_DIST_PLOT_PATH)
            image_tensor_final = torch.tensor(np.array(image)).permute(2,0,1)
            writer.add_image('Analysis/FinalProbabilityDistribution', image_tensor_final, global_step=1) # global_step可以设为1或与之前区分
            log.info("最终概率分布图已尝试记录到 TensorBoard。")
        except Exception as e_tb_img_final:
            log.warning(f"记录最终概率分布图到 TensorBoard 失败: {e_tb_img_final}")

    except Exception as e_plot_final:
        log.error(f"绘制或保存最终概率分布图时发生错误: {e_plot_final}", exc_info=True)

    # 保存最终概率分布数据到CSV
    try:
        with open(FINAL_PROB_DIST_CSV_PATH, 'w', newline='', encoding='utf-8') as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(['Probability_Threshold', 'Proportion_Samples_Above_Threshold'])
            for threshold, proportion in zip(final_prob_thresholds, final_sample_proportions):
                csv_writer.writerow([f"{threshold:.3f}", f"{proportion:.4f}"])
        log.info(f"最终概率分布数据已保存至 CSV: {FINAL_PROB_DIST_CSV_PATH}")
    except Exception as e_csv_final:
        log.error(f"保存最终概率分布数据到 CSV 时发生错误: {e_csv_final}", exc_info=True)

log.info("="*30 + " 最终测试集预测与结果分析结束 " + "="*30 + "\n")

writer.close()

2025-05-18 22:20:43,707 - __main__ - INFO - 检查最终预测与分析所需的前置变量...
2025-05-18 22:20:43,708 - __main__ - INFO - 所有最终预测分析所需的前置变量均已定义。
2025-05-18 22:20:43,708 - __main__ - INFO - 加载二次微调后的最佳模型权重从: ./files/saved_models/best_augmented_model.pt
  model.load_state_dict(torch.load(BEST_AUG_MODEL_PATH, map_location=device))
2025-05-18 22:20:44,468 - __main__ - INFO - 二次微调后的最佳模型权重从 ./files/saved_models/best_augmented_model.pt 加载成功。
2025-05-18 22:20:44,479 - __main__ - INFO - 开始对所有测试数据进行最终 Top-1 预测 (使用二次微调模型)...


Final Test Prediction (Top-1):   0%|          | 0/327 [00:00<?, ?it/s]

2025-05-18 22:22:14,496 - __main__ - INFO - 最终 Top-1 预测处理完成，共获得 83599 条预测结果。
2025-05-18 22:22:14,934 - __main__ - INFO - 最终测试集 Top-1 预测结果已保存至 JSON: ./files/results/final_test_top1_predictions.json
2025-05-18 22:22:14,935 - __main__ - INFO - JSON结果示例 (第一条):
2025-05-18 22:22:14,935 - __main__ - INFO - {
  "original_id": 0,
  "text": "北京君太百货璀璨秋色 满100省353020元",
  "predicted_label": "房产",
  "probability": 0.996
}
2025-05-18 22:22:14,968 - __main__ - INFO - 成功将 83599 条预测标签写入纯文本文件: ./files/results/result.txt
2025-05-18 22:22:14,968 - __main__ - INFO - result.txt 内容示例 (前3行，如果存在):
2025-05-18 22:22:14,969 - __main__ - INFO -   房产
2025-05-18 22:22:14,969 - __main__ - INFO -   时政
2025-05-18 22:22:14,969 - __main__ - INFO -   科技
2025-05-18 22:22:14,970 - __main__ - INFO - 开始统计和绘制最终 Top-1 预测概率分布图...
