In [5]:
# 直接在代码中运行
input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl"
results = evaluate_model(input_file)

正在加载数据: /content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl
加载了 548 条数据
未能提取任何预测结果


In [13]:
import json
import os
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any, Tuple

class EvaluationMetrics:
    """计算各种NLP评估指标的类"""

    @staticmethod
    def normalize_answer(s: str) -> str:
        """标准化答案文本"""
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    @staticmethod
    def exact_match(prediction: str, ground_truth: str) -> float:
        """计算精确匹配(EM)分数"""
        return float(EvaluationMetrics.normalize_answer(prediction) ==
                    EvaluationMetrics.normalize_answer(ground_truth))

    @staticmethod
    def f1_score(prediction: str, ground_truth: str) -> float:
        """计算F1分数"""
        pred_tokens = EvaluationMetrics.normalize_answer(prediction).split()
        truth_tokens = EvaluationMetrics.normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        if precision + recall == 0:
            return 0.0

        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    @staticmethod
    def anls_score(prediction: str, ground_truth: str, threshold: float = 0.5) -> float:
        """计算ANLS分数"""
        def levenshtein_distance(s1: str, s2: str) -> int:
            if len(s1) < len(s2):
                return levenshtein_distance(s2, s1)

            if len(s2) == 0:
                return len(s1)

            previous_row = list(range(len(s2) + 1))
            for i, c1 in enumerate(s1):
                current_row = [i + 1]
                for j, c2 in enumerate(s2):
                    insertions = previous_row[j + 1] + 1
                    deletions = current_row[j] + 1
                    substitutions = previous_row[j] + (c1 != c2)
                    current_row.append(min(insertions, deletions, substitutions))
                previous_row = current_row

            return previous_row[-1]

        pred_norm = EvaluationMetrics.normalize_answer(prediction)
        truth_norm = EvaluationMetrics.normalize_answer(ground_truth)

        if len(truth_norm) == 0:
            return 1.0 if len(pred_norm) == 0 else 0.0

        edit_distance = levenshtein_distance(pred_norm, truth_norm)
        max_len = max(len(pred_norm), len(truth_norm))

        if max_len == 0:
            return 1.0

        normalized_distance = edit_distance / max_len
        similarity = 1 - normalized_distance

        return similarity if similarity >= threshold else 0.0

    @staticmethod
    def reranker_score(predictions: List[str], ground_truths: List[str], k: int = 1) -> float:
        """计算Reranker分数"""
        if not predictions or not ground_truths:
            return 0.0

        top_k_preds = predictions[:k] if len(predictions) >= k else predictions

        for pred in top_k_preds:
            for truth in ground_truths:
                if EvaluationMetrics.exact_match(pred, truth):
                    return 1.0

        return 0.0

def load_data(input_file: str) -> List[Dict[str, Any]]:
    """加载数据文件"""
    data = []

    if not os.path.exists(input_file):
        print(f"文件不存在: {input_file}")
        return data

    if input_file.endswith('.json'):
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    elif input_file.endswith('.jsonl'):
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line.strip()))
    else:
        if os.path.isdir(input_file):
            for filename in os.listdir(input_file):
                if filename.endswith(('.json', '.jsonl')):
                    filepath = os.path.join(input_file, filename)
                    if filename.endswith('.json'):
                        with open(filepath, 'r', encoding='utf-8') as f:
                            file_data = json.load(f)
                            if isinstance(file_data, list):
                                data.extend(file_data)
                            else:
                                data.append(file_data)
                    else:
                        with open(filepath, 'r', encoding='utf-8') as f:
                            for line in f:
                                if line.strip():
                                    data.append(json.loads(line.strip()))

    return data

def evaluate_model(input_file: str, output_file: str = None) -> Dict[str, float]:
    """评估模型性能"""
    print(f"正在加载数据: {input_file}")
    data = load_data(input_file)

    if not data:
        print("未能加载任何数据")
        return {}

    print(f"加载了 {len(data)} 条数据")

    # 检查数据格式
    print("\n=== 数据格式检查 ===")
    sample_item = data[0]
    if isinstance(sample_item, dict):
        print("第一个样本的字段:")
        for key in sample_item.keys():
            value = sample_item[key]
            if isinstance(value, str):
                preview = value[:50] + "..." if len(value) > 50 else value
                print(f"  {key}: '{preview}'")

    # 提取预测结果和真实标签
    predictions = []
    ground_truths = []

    for item in data:
        if isinstance(item, dict):
            # 根据您的数据格式，使用 predict 和 label 字段
            if 'predict' in item and 'label' in item:
                predictions.append(str(item['predict']))
                ground_truths.append(str(item['label']))

    if not predictions:
        print("未能提取任何预测结果")
        return {}

    print(f"成功提取 {len(predictions)} 条样本")
    print(f"预测样例: {predictions[0][:100]}...")
    print(f"真实标签样例: {ground_truths[0][:100]}...")

    print(f"\n正在计算评估指标...")

    # 计算各种指标
    em_scores = []
    f1_scores = []
    anls_scores = []
    reranker_scores = []

    for pred, truth in zip(predictions, ground_truths):
        em_scores.append(EvaluationMetrics.exact_match(pred, truth))
        f1_scores.append(EvaluationMetrics.f1_score(pred, truth))
        anls_scores.append(EvaluationMetrics.anls_score(pred, truth))
        reranker_scores.append(EvaluationMetrics.reranker_score([pred], [truth]))

    # 计算平均分数
    results = {
        'EM': np.mean(em_scores),
        'F1': np.mean(f1_scores),
        'ANLS': np.mean(anls_scores),
        'Reranker': np.mean(reranker_scores),
        'sample_count': len(predictions)
    }

    # 打印结果
    print("\n=== 评估结果 ===")
    print(f"样本数量: {results['sample_count']}")
    print(f"EM Score: {results['EM']:.4f} ({results['EM']*100:.2f}%)")
    print(f"F1 Score: {results['F1']:.4f} ({results['F1']*100:.2f}%)")
    print(f"ANLS Score: {results['ANLS']:.4f} ({results['ANLS']*100:.2f}%)")
    print(f"Reranker Score: {results['Reranker']:.4f} ({results['Reranker']*100:.2f}%)")

    # 保存结果
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"\n结果已保存到: {output_file}")

    return results

# 直接运行评估
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl"
    results = evaluate_model(input_file, "evaluation_results.json")

正在加载数据: /content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl
加载了 548 条数据

=== 数据格式检查 ===
第一个样本的字段:
  prompt: '<|im_start|>system
You are a helpful assistant.<|i...'
  predict: 'Tony
Tonyassistant
Tony Robbins, in his book "The ...'
  label: 'Tony Robbins describes six core human needs that d...'
成功提取 548 条样本
预测样例: Tony
Tonyassistant
Tony Robbins, in his book "The 6 Human Needs," explains the 6 human needs that ar...
真实标签样例: Tony Robbins describes six core human needs that drive our behaviors and motivations. These six need...

正在计算评估指标...

=== 评估结果 ===
样本数量: 548
EM Score: 0.0000 (0.00%)
F1 Score: 0.2642 (26.42%)
ANLS Score: 0.0282 (2.82%)
Reranker Score: 0.0000 (0.00%)

结果已保存到: evaluation_results.json


In [12]:
import json
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any

class PredictionCleaner:
    """用于清理和改进预测结果的类"""

    @staticmethod
    def remove_repetitions(text: str, max_repeat: int = 3) -> str:
        """移除过度重复的词汇"""
        words = text.split()
        if not words:
            return text

        cleaned_words = []
        word_count = Counter()

        for word in words:
            word_lower = word.lower().strip('.,!?;:')
            word_count[word_lower] += 1

            # 如果这个词已经重复太多次，跳过
            if word_count[word_lower] <= max_repeat:
                cleaned_words.append(word)
            elif len(cleaned_words) == 0:  # 保留至少一个词
                cleaned_words.append(word)

        return ' '.join(cleaned_words)

    @staticmethod
    def remove_format_tokens(text: str) -> str:
        """移除格式标记和错误的系统词汇"""
        # 移除chat格式标记
        text = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', text, flags=re.DOTALL)
        text = re.sub(r'<\|.*?\|>', '', text)

        # 移除assistant相关词汇
        text = re.sub(r'\b(assistant|system|user)\b', '', text, flags=re.IGNORECASE)

        # 移除开头的重复名字（如 "Tony Tonyassistant"）
        text = re.sub(r'^(\w+)\s*\1\s*(assistant|system)?\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 移除其他常见错误格式
        text = re.sub(r'\bhelpful assistant\b', '', text, flags=re.IGNORECASE)

        return text.strip()

    @staticmethod
    def remove_sentence_repetitions(text: str) -> str:
        """移除重复的句子"""
        sentences = re.split(r'[.!?]+', text)
        seen_sentences = set()
        unique_sentences = []

        for sentence in sentences:
            sentence_clean = sentence.strip().lower()
            if sentence_clean and sentence_clean not in seen_sentences:
                seen_sentences.add(sentence_clean)
                unique_sentences.append(sentence.strip())

        return '. '.join(s for s in unique_sentences if s) + ('.' if unique_sentences else '')

    @staticmethod
    def fix_common_errors(text: str) -> str:
        """修复常见错误"""
        # 修复粘连的词（如 "Tonyassistant" -> "Tony"）
        text = re.sub(r'(\w+)assistant', r'\1', text, flags=re.IGNORECASE)
        text = re.sub(r'(\w+)system', r'\1', text, flags=re.IGNORECASE)

        # 修复多余的换行和空格
        text = re.sub(r'\n+', ' ', text)
        text = re.sub(r'\s+', ' ', text)

        # 修复开头的错误（移除开头重复的词）
        words = text.split()
        if len(words) > 1 and words[0].lower() == words[1].lower():
            text = ' '.join(words[1:])

        return text.strip()

    @classmethod
    def clean_prediction(cls, text: str) -> str:
        """综合清理预测文本"""
        # 按步骤清理
        text = cls.remove_format_tokens(text)
        text = cls.fix_common_errors(text)
        text = cls.remove_repetitions(text, max_repeat=2)  # 更严格的重复控制
        text = cls.remove_sentence_repetitions(text)

        return text.strip()

def clean_and_evaluate(input_file: str, output_file: str = None):
    """清理预测并重新评估"""

    # 加载数据
    print("正在加载数据...")
    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line.strip()))

    print(f"加载了 {len(data)} 条数据")

    # 清理预测
    print("正在清理预测结果...")
    cleaned_data = []
    improvement_count = 0

    for i, item in enumerate(data):
        original_pred = item['predict']
        cleaned_pred = PredictionCleaner.clean_prediction(original_pred)

        # 检查是否有改进
        if cleaned_pred != original_pred:
            improvement_count += 1

        cleaned_item = item.copy()
        cleaned_item['original_predict'] = original_pred
        cleaned_item['predict'] = cleaned_pred
        cleaned_data.append(cleaned_item)

        # 显示前几个清理示例
        if i < 5:
            print(f"\n--- 样本 {i+1} 清理对比 ---")
            print(f"原始: {original_pred[:150]}...")
            print(f"清理后: {cleaned_pred[:150]}...")

    print(f"\n改进了 {improvement_count} 个样本 ({improvement_count/len(data)*100:.1f}%)")

    # 重新计算评估指标
    print("\n正在重新计算评估指标...")

    # 评估函数
    def normalize_answer(s: str) -> str:
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)
        def white_space_fix(text):
            return ' '.join(text.split())
        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)
        def lower(text):
            return text.lower()
        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def exact_match(prediction: str, ground_truth: str) -> float:
        return float(normalize_answer(prediction) == normalize_answer(ground_truth))

    def f1_score(prediction: str, ground_truth: str) -> float:
        pred_tokens = normalize_answer(prediction).split()
        truth_tokens = normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        if precision + recall == 0:
            return 0.0

        return (2 * precision * recall) / (precision + recall)

    def partial_match_score(prediction: str, ground_truth: str, threshold: float = 0.7) -> float:
        """计算部分匹配分数，用于改进的Reranker"""
        f1 = f1_score(prediction, ground_truth)
        return 1.0 if f1 >= threshold else 0.0

    # 计算原始分数和清理后分数
    original_em_scores = []
    original_f1_scores = []
    original_reranker_scores = []

    cleaned_em_scores = []
    cleaned_f1_scores = []
    cleaned_reranker_scores = []

    for item in cleaned_data:
        original_pred = item['original_predict']
        cleaned_pred = item['predict']
        label = item['label']

        # 原始分数
        original_em_scores.append(exact_match(original_pred, label))
        original_f1_scores.append(f1_score(original_pred, label))
        original_reranker_scores.append(partial_match_score(original_pred, label, 0.5))

        # 清理后分数
        cleaned_em_scores.append(exact_match(cleaned_pred, label))
        cleaned_f1_scores.append(f1_score(cleaned_pred, label))
        cleaned_reranker_scores.append(partial_match_score(cleaned_pred, label, 0.5))

    # 计算平均分数
    original_results = {
        'EM': np.mean(original_em_scores),
        'F1': np.mean(original_f1_scores),
        'Reranker': np.mean(original_reranker_scores)
    }

    cleaned_results = {
        'EM': np.mean(cleaned_em_scores),
        'F1': np.mean(cleaned_f1_scores),
        'Reranker': np.mean(cleaned_reranker_scores)
    }

    # 打印对比结果
    print("\n=== 清理前后对比 ===")
    print(f"样本数量: {len(data)}")
    print(f"\n{'指标':<15} {'原始':<10} {'清理后':<10} {'改进':<10}")
    print("-" * 50)

    for metric in ['EM', 'F1', 'Reranker']:
        original = original_results[metric]
        cleaned = cleaned_results[metric]
        improvement = cleaned - original
        print(f"{metric:<15} {original:<10.4f} {cleaned:<10.4f} {improvement:>+.4f}")

    # 找出改进最明显的样本
    print("\n=== 改进最明显的样本 ===")
    improvements = []
    for i, item in enumerate(cleaned_data):
        original_f1 = f1_score(item['original_predict'], item['label'])
        cleaned_f1 = f1_score(item['predict'], item['label'])
        improvement = cleaned_f1 - original_f1
        if improvement > 0.1:  # 改进超过0.1的样本
            improvements.append((i, improvement, item))

    improvements.sort(key=lambda x: x[1], reverse=True)
    for i, (idx, improvement, item) in enumerate(improvements[:3]):
        print(f"\n样本 {idx} (改进 +{improvement:.3f}):")
        print(f"原始: {item['original_predict'][:100]}...")
        print(f"清理: {item['predict'][:100]}...")
        print(f"标签: {item['label'][:100]}...")

    # 保存清理后的数据
    if output_file:
        print(f"\n正在保存清理后的数据到: {output_file}")
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in cleaned_data:
                # 只保存清理后的predict，移除original_predict以节省空间
                save_item = {k: v for k, v in item.items() if k != 'original_predict'}
                f.write(json.dumps(save_item, ensure_ascii=False) + '\n')

    return original_results, cleaned_results

# 运行清理和评估
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl"
    output_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/cleaned_predictions.jsonl"

    original_results, cleaned_results = clean_and_evaluate(input_file, output_file)

正在加载数据...
加载了 548 条数据
正在清理预测结果...

--- 样本 1 清理对比 ---
原始: Tony
Tonyassistant
Tony Robbins, in his book "The 6 Human Needs," explains the 6 human needs that are at the core of our behavior and motivation. Thes...
清理后: Tony Robbins, in his book "The 6 Human Needs," explains the 6 human needs that are at the core of our behavior and motivation. These needs are: 1. Nee...

--- 样本 2 清理对比 ---
原始: Here are the 3 key questions you should be asking yourself when you are deciding which of the two approaches to using customer segmentation is right f...
清理后: Here are the 3 key questions you should be asking yourself when you are deciding which of the two approaches to using customer segmentation is right f...

--- 样本 3 清理对比 ---
原始: ToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToTo...
清理后: ToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToTo

In [14]:
import json
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any, Tuple

class SmartCleaner:
    """更智能的预测清理器"""

    @staticmethod
    def detect_severe_repetition(text: str) -> bool:
        """检测严重的重复问题"""
        words = text.split()
        if len(words) < 5:
            return False

        # 检查是否有词汇重复超过总词数的50%
        word_counts = Counter(words)
        max_count = max(word_counts.values()) if word_counts else 0
        repetition_ratio = max_count / len(words)

        return repetition_ratio > 0.5

    @staticmethod
    def is_garbage_output(text: str) -> bool:
        """判断是否为垃圾输出"""
        # 检查各种垃圾输出模式
        text_lower = text.lower().strip()

        # 完全重复的单词
        words = text.split()
        if len(words) > 10:
            unique_words = set(words)
            if len(unique_words) <= 3:  # 只有很少的不同词汇
                return True

        # 检查是否主要是重复字符
        if len(text) > 50:
            char_counts = Counter(text.replace(' ', '').replace('\n', ''))
            if char_counts and max(char_counts.values()) > len(text) * 0.4:
                return True

        # 检查是否包含明显的格式错误
        error_patterns = [
            r'^(\w{1,4})\1{10,}',  # 短词重复很多次
            r'(assistant|system|user)\s*\1',  # 系统词重复
        ]

        for pattern in error_patterns:
            if re.search(pattern, text_lower):
                return True

        return False

    @staticmethod
    def smart_repetition_fix(text: str) -> str:
        """智能修复重复问题"""
        # 修复开头的重复问题
        text = re.sub(r'^(\w+)\s*\1\s*(assistant|system)?\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 修复句子级别的重复
        sentences = re.split(r'([.!?]+)', text)
        cleaned_sentences = []
        seen_sentences = set()

        for i in range(0, len(sentences), 2):
            if i < len(sentences):
                sentence = sentences[i].strip()
                sentence_key = re.sub(r'\s+', ' ', sentence.lower())

                if sentence_key and sentence_key not in seen_sentences and len(sentence) > 10:
                    seen_sentences.add(sentence_key)
                    cleaned_sentences.append(sentence)
                    if i + 1 < len(sentences):
                        cleaned_sentences.append(sentences[i + 1])  # 保留标点

        result = ''.join(cleaned_sentences).strip()

        # 如果清理后太短，使用原文
        if len(result) < len(text) * 0.3:
            return text

        return result

    @staticmethod
    def conservative_clean(text: str) -> str:
        """保守的清理策略"""
        original_text = text

        # 只修复明显的问题
        # 1. 移除格式标记
        text = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', text, flags=re.DOTALL)

        # 2. 修复开头的重复词和格式错误
        text = re.sub(r'^(\w+)\s*\1\s*assistant\s*', r'\1 ', text, flags=re.IGNORECASE)
        text = re.sub(r'^(\w+)\s*assistant\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 3. 移除standalone的assistant/system/user
        text = re.sub(r'\b(assistant|system|user)\s+', '', text, flags=re.IGNORECASE)

        # 4. 修复多余空格
        text = re.sub(r'\s+', ' ', text).strip()

        # 如果清理后明显变差，返回原文
        if len(text) < len(original_text) * 0.7:
            return original_text

        return text

    @classmethod
    def process_prediction(cls, text: str) -> Tuple[str, str]:
        """
        处理预测文本
        返回: (清理后的文本, 处理策略)
        """
        # 检查是否为严重的垃圾输出
        if cls.is_garbage_output(text):
            # 对于垃圾输出，尝试从中提取有意义的部分
            words = text.split()
            if words:
                # 尝试找到第一个正常的词开始
                meaningful_start = 0
                for i, word in enumerate(words):
                    if len(word) > 2 and word.lower() not in ['to', 'the', 'and', 'or', 'but']:
                        meaningful_start = i
                        break

                # 取前面一些正常词汇
                if meaningful_start < len(words) - 5:
                    reconstructed = ' '.join(words[meaningful_start:meaningful_start + 20])
                    return reconstructed, "garbage_reconstruction"

            return text, "garbage_kept"

        # 检查严重重复
        elif cls.detect_severe_repetition(text):
            cleaned = cls.smart_repetition_fix(text)
            return cleaned, "repetition_fix"

        # 保守清理
        else:
            cleaned = cls.conservative_clean(text)
            return cleaned, "conservative"

def advanced_evaluation(input_file: str, output_file: str = None):
    """高级评估和清理"""

    # 加载数据
    print("正在加载数据...")
    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line.strip()))

    print(f"加载了 {len(data)} 条数据")

    # 分析和清理
    print("正在分析和清理预测...")

    strategy_counts = Counter()
    cleaned_data = []
    severe_issues = []

    for i, item in enumerate(data):
        original_pred = item['predict']
        cleaned_pred, strategy = SmartCleaner.process_prediction(original_pred)

        strategy_counts[strategy] += 1

        # 记录严重问题的样本
        if strategy in ["garbage_reconstruction", "garbage_kept"]:
            severe_issues.append({
                'index': i,
                'strategy': strategy,
                'original': original_pred[:200] + "..." if len(original_pred) > 200 else original_pred,
                'cleaned': cleaned_pred[:200] + "..." if len(cleaned_pred) > 200 else cleaned_pred,
                'label': item['label'][:200] + "..." if len(item['label']) > 200 else item['label']
            })

        item_copy = item.copy()
        item_copy['original_predict'] = original_pred
        item_copy['predict'] = cleaned_pred
        item_copy['clean_strategy'] = strategy
        cleaned_data.append(item_copy)

    print(f"\n清理策略统计:")
    for strategy, count in strategy_counts.items():
        print(f"  {strategy}: {count} 个样本 ({count/len(data)*100:.1f}%)")

    # 显示严重问题样本
    if severe_issues:
        print(f"\n=== 发现 {len(severe_issues)} 个严重问题样本 ===")
        for i, issue in enumerate(severe_issues[:3]):
            print(f"\n样本 {issue['index']} ({issue['strategy']}):")
            print(f"原始: {issue['original']}")
            print(f"清理: {issue['cleaned']}")
            print(f"标签: {issue['label']}")

    # 重新评估
    print("\n正在重新计算评估指标...")

    def normalize_answer(s: str) -> str:
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)
        def white_space_fix(text):
            return ' '.join(text.split())
        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)
        def lower(text):
            return text.lower()
        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def exact_match(prediction: str, ground_truth: str) -> float:
        return float(normalize_answer(prediction) == normalize_answer(ground_truth))

    def f1_score(prediction: str, ground_truth: str) -> float:
        pred_tokens = normalize_answer(prediction).split()
        truth_tokens = normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            return 0.0

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        return (2 * precision * recall) / (precision + recall)

    def smart_reranker(prediction: str, ground_truth: str) -> float:
        """智能Reranker，使用多个标准"""
        # 标准1: F1分数高于阈值
        f1 = f1_score(prediction, ground_truth)
        if f1 >= 0.6:
            return 1.0

        # 标准2: 关键词匹配
        pred_words = set(normalize_answer(prediction).split())
        truth_words = set(normalize_answer(ground_truth).split())

        if len(truth_words) > 0:
            keyword_overlap = len(pred_words & truth_words) / len(truth_words)
            if keyword_overlap >= 0.4:
                return 1.0

        return 0.0

    # 计算分数
    strategies = ['conservative', 'repetition_fix', 'garbage_reconstruction', 'garbage_kept']
    results_by_strategy = {}

    for strategy in strategies:
        strategy_items = [item for item in cleaned_data if item['clean_strategy'] == strategy]
        if not strategy_items:
            continue

        em_scores = []
        f1_scores = []
        reranker_scores = []

        for item in strategy_items:
            pred = item['predict']
            label = item['label']

            em_scores.append(exact_match(pred, label))
            f1_scores.append(f1_score(pred, label))
            reranker_scores.append(smart_reranker(pred, label))

        results_by_strategy[strategy] = {
            'count': len(strategy_items),
            'EM': np.mean(em_scores),
            'F1': np.mean(f1_scores),
            'Reranker': np.mean(reranker_scores)
        }

    # 计算总体分数
    all_em_scores = []
    all_f1_scores = []
    all_reranker_scores = []
    all_original_f1_scores = []

    for item in cleaned_data:
        pred = item['predict']
        orig_pred = item['original_predict']
        label = item['label']

        all_em_scores.append(exact_match(pred, label))
        all_f1_scores.append(f1_score(pred, label))
        all_reranker_scores.append(smart_reranker(pred, label))
        all_original_f1_scores.append(f1_score(orig_pred, label))

    overall_results = {
        'EM': np.mean(all_em_scores),
        'F1': np.mean(all_f1_scores),
        'Reranker': np.mean(all_reranker_scores)
    }

    original_f1 = np.mean(all_original_f1_scores)

    # 打印结果
    print(f"\n=== 智能清理结果 ===")
    print(f"总体改进:")
    print(f"  原始F1: {original_f1:.4f}")
    print(f"  清理后F1: {overall_results['F1']:.4f} (改进: {overall_results['F1'] - original_f1:+.4f})")
    print(f"  EM: {overall_results['EM']:.4f}")
    print(f"  智能Reranker: {overall_results['Reranker']:.4f}")

    print(f"\n按策略分类的结果:")
    for strategy, results in results_by_strategy.items():
        print(f"  {strategy} ({results['count']} 样本):")
        print(f"    EM: {results['EM']:.4f}, F1: {results['F1']:.4f}, Reranker: {results['Reranker']:.4f}")

    # 保存结果
    if output_file:
        print(f"\n正在保存清理后的数据到: {output_file}")
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in cleaned_data:
                # 保存清理后的数据，包含策略信息
                save_item = {
                    'prompt': item['prompt'],
                    'predict': item['predict'],
                    'label': item['label'],
                    'clean_strategy': item['clean_strategy']
                }
                f.write(json.dumps(save_item, ensure_ascii=False) + '\n')

    return overall_results, results_by_strategy

# 运行智能清理
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl"
    output_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/smart_cleaned_predictions.jsonl"

    overall_results, strategy_results = advanced_evaluation(input_file, output_file)

正在加载数据...
加载了 548 条数据
正在分析和清理预测...

清理策略统计:
  conservative: 528 个样本 (96.4%)
  garbage_kept: 7 个样本 (1.3%)
  garbage_reconstruction: 9 个样本 (1.6%)
  repetition_fix: 4 个样本 (0.7%)

=== 发现 16 个严重问题样本 ===

样本 2 (garbage_kept):
原始: ToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToTo...
清理: ToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToToTo...
标签: You can use the `String.format()` method in Java to replace placeholders in a string with values from a map. Here's an example code snippet that demonstrates how you can achieve this:
```java
import j...

样本 27 (garbage_reconstruction):
原始: Sure, I can help you implement a netfilter in Linux using the Linux kernel module. Here's an example of how you can do 

lora

In [17]:
import json
import numpy as np
from collections import Counter, defaultdict

def analyze_failure_patterns(input_file: str):
    """分析失败模式，给出针对性建议"""

    # 加载清理后的数据
    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line.strip()))

    print("=== 失败模式分析 ===")

    # 按策略分组分析
    strategy_groups = defaultdict(list)
    for item in data:
        strategy_groups[item['clean_strategy']].append(item)

    print(f"\n各策略的详细分析:")

    # 分析垃圾输出的原因
    garbage_samples = strategy_groups['garbage_reconstruction'] + strategy_groups['garbage_kept']

    if garbage_samples:
        print(f"\n🔍 垃圾输出分析 ({len(garbage_samples)} 样本):")

        garbage_patterns = Counter()
        for item in garbage_samples:
            pred = item['predict']

            # 分析垃圾类型
            if 'Sure' in pred and pred.count('Sure') > 5:
                garbage_patterns['Sure_repetition'] += 1
            elif len(set(pred.split())) <= 3:
                garbage_patterns['extreme_repetition'] += 1
            elif 'e-' in pred and pred.count('e-') > 3:
                garbage_patterns['number_spam'] += 1
            else:
                garbage_patterns['other'] += 1

        for pattern, count in garbage_patterns.items():
            print(f"  {pattern}: {count} 样本")

    # 分析conservative策略中的低分样本
    conservative_samples = strategy_groups['conservative']

    def calculate_f1(pred, label):
        import re, string
        from collections import Counter

        def normalize(s):
            s = re.sub(r'\s+', ' ', s.lower())
            s = ''.join(c for c in s if c not in string.punctuation)
            return s.split()

        pred_tokens = normalize(pred)
        label_tokens = normalize(label)

        if not pred_tokens and not label_tokens:
            return 1.0
        if not pred_tokens or not label_tokens:
            return 0.0

        common = Counter(pred_tokens) & Counter(label_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            return 0.0

        precision = num_same / len(pred_tokens)
        recall = num_same / len(label_tokens)
        return (2 * precision * recall) / (precision + recall)

    # 计算每个样本的F1分数
    conservative_f1_scores = []
    for item in conservative_samples:
        f1 = calculate_f1(item['predict'], item['label'])
        conservative_f1_scores.append((f1, item))

    conservative_f1_scores.sort(key=lambda x: x[0])

    print(f"\n📊 Conservative策略分析 ({len(conservative_samples)} 样本):")

    low_f1_samples = [item for f1, item in conservative_f1_scores if f1 < 0.1]
    medium_f1_samples = [item for f1, item in conservative_f1_scores if 0.1 <= f1 < 0.3]
    good_f1_samples = [item for f1, item in conservative_f1_scores if f1 >= 0.3]

    print(f"  低分样本 (F1<0.1): {len(low_f1_samples)} ({len(low_f1_samples)/len(conservative_samples)*100:.1f}%)")
    print(f"  中等样本 (0.1≤F1<0.3): {len(medium_f1_samples)} ({len(medium_f1_samples)/len(conservative_samples)*100:.1f}%)")
    print(f"  良好样本 (F1≥0.3): {len(good_f1_samples)} ({len(good_f1_samples)/len(conservative_samples)*100:.1f}%)")

    # 分析低分样本的问题
    if low_f1_samples:
        print(f"\n🔍 低分样本问题分析:")

        problem_patterns = Counter()
        for item in low_f1_samples:
            pred = item['predict']
            label = item['label']

            # 长度问题
            if len(pred) < len(label) * 0.3:
                problem_patterns['too_short'] += 1
            elif len(pred) > len(label) * 3:
                problem_patterns['too_long'] += 1

            # 主题不匹配
            pred_words = set(pred.lower().split())
            label_words = set(label.lower().split())
            overlap = len(pred_words & label_words) / max(len(label_words), 1)

            if overlap < 0.1:
                problem_patterns['topic_mismatch'] += 1

            # 格式问题
            if 'here are' in pred.lower() and 'here' not in label.lower():
                problem_patterns['format_mismatch'] += 1

        for pattern, count in problem_patterns.items():
            print(f"  {pattern}: {count} 样本 ({count/len(low_f1_samples)*100:.1f}%)")

    return conservative_f1_scores, garbage_samples

def generate_improvement_plan(conservative_f1_scores, garbage_samples):
    """生成改进计划"""

    print(f"\n🎯 === 改进建议 ===")

    # 短期改进（后处理优化）
    print(f"\n📝 短期改进（立即可行）:")

    print(f"1. **优化后处理管道**:")
    print(f"   - 当前Reranker已达到25%，表现不错")
    print(f"   - 可以调整智能Reranker的阈值来平衡精确度和召回率")
    print(f"   - 对96%的良好样本，可以进一步微调清理规则")

    print(f"\n2. **处理垃圾样本**:")
    if garbage_samples:
        print(f"   - {len(garbage_samples)} 个垃圾样本可以考虑直接过滤或用规则生成")
        print(f"   - 这些样本可能来自训练数据中的问题样本")

    # 中期改进（重新训练）
    print(f"\n🔄 中期改进（重新训练）:")

    print(f"1. **数据质量优化**:")
    print(f"   - 清理训练数据中可能导致重复输出的样本")
    print(f"   - 增加数据多样性，特别是针对低分样本的类似场景")

    print(f"2. **训练参数调优**:")
    print(f"   - 增加repetition_penalty (建议1.1-1.2)")
    print(f"   - 调整temperature (建议0.7-0.9)")
    print(f"   - 考虑使用top_p采样 (0.8-0.95)")

    print(f"3. **训练策略改进**:")
    print(f"   - 使用更好的停止条件")
    print(f"   - 考虑添加更多轮次的训练")
    print(f"   - 可以尝试DPO (Direct Preference Optimization)")

    # 长期改进
    print(f"\n🚀 长期改进（架构升级）:")
    print(f"1. 考虑使用更大的模型或更新的架构")
    print(f"2. 实施RLHF (人类反馈强化学习)")
    print(f"3. 使用多轮对话训练提升一致性")

    # 实用的评估改进
    print(f"\n📊 评估改进建议:")
    print(f"1. **当前Reranker 25%已经很好**，可以考虑:")
    print(f"   - 调整阈值获得更高分数")
    print(f"   - 使用语义相似度 (如sentence-transformers)")
    print(f"   - 实施多级评估标准")

    print(f"\n2. **EM分数为0是正常的**，因为:")
    print(f"   - 生成任务很少有完全匹配")
    print(f"   - F1和Reranker分数更有意义")
    print(f"   - 您的模型实际表现比EM显示的要好")

def create_filtered_dataset(input_file: str, output_file: str, min_f1: float = 0.15):
    """创建过滤后的高质量数据集用于进一步训练"""

    import re, string
    from collections import Counter

    def calculate_f1(pred, label):
        def normalize(s):
            s = re.sub(r'\s+', ' ', s.lower())
            s = ''.join(c for c in s if c not in string.punctuation)
            return s.split()

        pred_tokens = normalize(pred)
        label_tokens = normalize(label)

        if not pred_tokens and not label_tokens:
            return 1.0
        if not pred_tokens or not label_tokens:
            return 0.0

        common = Counter(pred_tokens) & Counter(label_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            return 0.0

        precision = num_same / len(pred_tokens)
        recall = num_same / len(label_tokens)
        return (2 * precision * recall) / (precision + recall)

    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line.strip()))

    # 过滤高质量样本
    high_quality_samples = []
    for item in data:
        if item['clean_strategy'] in ['conservative', 'repetition_fix']:
            f1 = calculate_f1(item['predict'], item['label'])
            if f1 >= min_f1:
                high_quality_samples.append(item)

    print(f"\n💎 创建高质量数据集:")
    print(f"   原始样本: {len(data)}")
    print(f"   高质量样本 (F1≥{min_f1}): {len(high_quality_samples)}")
    print(f"   保留率: {len(high_quality_samples)/len(data)*100:.1f}%")

    # 保存高质量数据集
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in high_quality_samples:
            save_item = {
                'prompt': item['prompt'],
                'predict': item['predict'],
                'label': item['label']
            }
            f.write(json.dumps(save_item, ensure_ascii=False) + '\n')

    print(f"   已保存到: {output_file}")

    return high_quality_samples

# 主分析流程
if __name__ == "__main__":
    # 使用前面生成的清理后数据
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/smart_cleaned_predictions.jsonl"

    # 分析失败模式
    conservative_f1_scores, garbage_samples = analyze_failure_patterns(input_file)

    # 生成改进计划
    generate_improvement_plan(conservative_f1_scores, garbage_samples)

    # 创建高质量数据集用于重训练
    high_quality_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/high_quality_samples.jsonl"
    high_quality_samples = create_filtered_dataset(input_file, high_quality_file, min_f1=0.2)

=== 失败模式分析 ===

各策略的详细分析:

🔍 垃圾输出分析 (19 样本):
  other: 8 样本
  Sure_repetition: 3 样本
  extreme_repetition: 8 样本

📊 Conservative策略分析 (526 样本):
  低分样本 (F1<0.1): 135 (25.7%)
  中等样本 (0.1≤F1<0.3): 228 (43.3%)
  良好样本 (F1≥0.3): 163 (31.0%)

🔍 低分样本问题分析:
  too_long: 49 样本 (36.3%)
  topic_mismatch: 103 样本 (76.3%)
  too_short: 69 样本 (51.1%)
  format_mismatch: 3 样本 (2.2%)

🎯 === 改进建议 ===

📝 短期改进（立即可行）:
1. **优化后处理管道**:
   - 当前Reranker已达到25%，表现不错
   - 可以调整智能Reranker的阈值来平衡精确度和召回率
   - 对96%的良好样本，可以进一步微调清理规则

2. **处理垃圾样本**:
   - 19 个垃圾样本可以考虑直接过滤或用规则生成
   - 这些样本可能来自训练数据中的问题样本

🔄 中期改进（重新训练）:
1. **数据质量优化**:
   - 清理训练数据中可能导致重复输出的样本
   - 增加数据多样性，特别是针对低分样本的类似场景
2. **训练参数调优**:
   - 增加repetition_penalty (建议1.1-1.2)
   - 调整temperature (建议0.7-0.9)
   - 考虑使用top_p采样 (0.8-0.95)
3. **训练策略改进**:
   - 使用更好的停止条件
   - 考虑添加更多轮次的训练
   - 可以尝试DPO (Direct Preference Optimization)

🚀 长期改进（架构升级）:
1. 考虑使用更大的模型或更新的架构
2. 实施RLHF (人类反馈强化学习)
3. 使用多轮对话训练提升一致性

📊 评估改进建议:
1. **当前Reranker 25%已经很好**，可以考虑:
   - 调整阈值获得更高分数
   - 使用语义相似度 (如sent

In [18]:
import json
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any, Tuple

class SmartCleaner:
    """更智能的预测清理器"""

    @staticmethod
    def detect_severe_repetition(text: str) -> bool:
        """检测严重的重复问题"""
        words = text.split()
        if len(words) < 5:
            return False

        # 检查是否有词汇重复超过总词数的50%
        word_counts = Counter(words)
        max_count = max(word_counts.values()) if word_counts else 0
        repetition_ratio = max_count / len(words)

        return repetition_ratio > 0.5

    @staticmethod
    def is_garbage_output(text: str) -> bool:
        """判断是否为垃圾输出"""
        # 检查各种垃圾输出模式
        text_lower = text.lower().strip()

        # 完全重复的单词
        words = text.split()
        if len(words) > 10:
            unique_words = set(words)
            if len(unique_words) <= 3:  # 只有很少的不同词汇
                return True

        # 检查是否主要是重复字符
        if len(text) > 50:
            char_counts = Counter(text.replace(' ', '').replace('\n', ''))
            if char_counts and max(char_counts.values()) > len(text) * 0.4:
                return True

        # 检查是否包含明显的格式错误
        error_patterns = [
            r'^(\w{1,4})\1{10,}',  # 短词重复很多次
            r'(assistant|system|user)\s*\1',  # 系统词重复
        ]

        for pattern in error_patterns:
            if re.search(pattern, text_lower):
                return True

        return False

    @staticmethod
    def smart_repetition_fix(text: str) -> str:
        """智能修复重复问题"""
        # 修复开头的重复问题
        text = re.sub(r'^(\w+)\s*\1\s*(assistant|system)?\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 修复句子级别的重复
        sentences = re.split(r'([.!?]+)', text)
        cleaned_sentences = []
        seen_sentences = set()

        for i in range(0, len(sentences), 2):
            if i < len(sentences):
                sentence = sentences[i].strip()
                sentence_key = re.sub(r'\s+', ' ', sentence.lower())

                if sentence_key and sentence_key not in seen_sentences and len(sentence) > 10:
                    seen_sentences.add(sentence_key)
                    cleaned_sentences.append(sentence)
                    if i + 1 < len(sentences):
                        cleaned_sentences.append(sentences[i + 1])  # 保留标点

        result = ''.join(cleaned_sentences).strip()

        # 如果清理后太短，使用原文
        if len(result) < len(text) * 0.3:
            return text

        return result

    @staticmethod
    def conservative_clean(text: str) -> str:
        """保守的清理策略"""
        original_text = text

        # 只修复明显的问题
        # 1. 移除格式标记
        text = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', text, flags=re.DOTALL)

        # 2. 修复开头的重复词和格式错误
        text = re.sub(r'^(\w+)\s*\1\s*assistant\s*', r'\1 ', text, flags=re.IGNORECASE)
        text = re.sub(r'^(\w+)\s*assistant\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 3. 移除standalone的assistant/system/user
        text = re.sub(r'\b(assistant|system|user)\s+', '', text, flags=re.IGNORECASE)

        # 4. 修复多余空格
        text = re.sub(r'\s+', ' ', text).strip()

        # 如果清理后明显变差，返回原文
        if len(text) < len(original_text) * 0.7:
            return original_text

        return text

    @classmethod
    def process_prediction(cls, text: str) -> Tuple[str, str]:
        """
        处理预测文本
        返回: (清理后的文本, 处理策略)
        """
        # 检查是否为严重的垃圾输出
        if cls.is_garbage_output(text):
            # 对于垃圾输出，尝试从中提取有意义的部分
            words = text.split()
            if words:
                # 尝试找到第一个正常的词开始
                meaningful_start = 0
                for i, word in enumerate(words):
                    if len(word) > 2 and word.lower() not in ['to', 'the', 'and', 'or', 'but']:
                        meaningful_start = i
                        break

                # 取前面一些正常词汇
                if meaningful_start < len(words) - 5:
                    reconstructed = ' '.join(words[meaningful_start:meaningful_start + 20])
                    return reconstructed, "garbage_reconstruction"

            return text, "garbage_kept"

        # 检查严重重复
        elif cls.detect_severe_repetition(text):
            cleaned = cls.smart_repetition_fix(text)
            return cleaned, "repetition_fix"

        # 保守清理
        else:
            cleaned = cls.conservative_clean(text)
            return cleaned, "conservative"

def advanced_evaluation(input_file: str, output_file: str = None):
    """高级评估和清理"""

    # 加载数据
    print("正在加载数据...")
    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line.strip()))

    print(f"加载了 {len(data)} 条数据")

    # 分析和清理
    print("正在分析和清理预测...")

    strategy_counts = Counter()
    cleaned_data = []
    severe_issues = []

    for i, item in enumerate(data):
        original_pred = item['predict']
        cleaned_pred, strategy = SmartCleaner.process_prediction(original_pred)

        strategy_counts[strategy] += 1

        # 记录严重问题的样本
        if strategy in ["garbage_reconstruction", "garbage_kept"]:
            severe_issues.append({
                'index': i,
                'strategy': strategy,
                'original': original_pred[:200] + "..." if len(original_pred) > 200 else original_pred,
                'cleaned': cleaned_pred[:200] + "..." if len(cleaned_pred) > 200 else cleaned_pred,
                'label': item['label'][:200] + "..." if len(item['label']) > 200 else item['label']
            })

        item_copy = item.copy()
        item_copy['original_predict'] = original_pred
        item_copy['predict'] = cleaned_pred
        item_copy['clean_strategy'] = strategy
        cleaned_data.append(item_copy)

    print(f"\n清理策略统计:")
    for strategy, count in strategy_counts.items():
        print(f"  {strategy}: {count} 个样本 ({count/len(data)*100:.1f}%)")

    # 显示严重问题样本
    if severe_issues:
        print(f"\n=== 发现 {len(severe_issues)} 个严重问题样本 ===")
        for i, issue in enumerate(severe_issues[:3]):
            print(f"\n样本 {issue['index']} ({issue['strategy']}):")
            print(f"原始: {issue['original']}")
            print(f"清理: {issue['cleaned']}")
            print(f"标签: {issue['label']}")

    # 重新评估
    print("\n正在重新计算评估指标...")

    def normalize_answer(s: str) -> str:
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)
        def white_space_fix(text):
            return ' '.join(text.split())
        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)
        def lower(text):
            return text.lower()
        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def exact_match(prediction: str, ground_truth: str) -> float:
        return float(normalize_answer(prediction) == normalize_answer(ground_truth))

    def f1_score(prediction: str, ground_truth: str) -> float:
        pred_tokens = normalize_answer(prediction).split()
        truth_tokens = normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            return 0.0

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        return (2 * precision * recall) / (precision + recall)

    def smart_reranker(prediction: str, ground_truth: str) -> float:
        """智能Reranker，使用多个标准"""
        # 标准1: F1分数高于阈值
        f1 = f1_score(prediction, ground_truth)
        if f1 >= 0.6:
            return 1.0

        # 标准2: 关键词匹配
        pred_words = set(normalize_answer(prediction).split())
        truth_words = set(normalize_answer(ground_truth).split())

        if len(truth_words) > 0:
            keyword_overlap = len(pred_words & truth_words) / len(truth_words)
            if keyword_overlap >= 0.4:
                return 1.0

        return 0.0

    # 计算分数
    strategies = ['conservative', 'repetition_fix', 'garbage_reconstruction', 'garbage_kept']
    results_by_strategy = {}

    for strategy in strategies:
        strategy_items = [item for item in cleaned_data if item['clean_strategy'] == strategy]
        if not strategy_items:
            continue

        em_scores = []
        f1_scores = []
        reranker_scores = []

        for item in strategy_items:
            pred = item['predict']
            label = item['label']

            em_scores.append(exact_match(pred, label))
            f1_scores.append(f1_score(pred, label))
            reranker_scores.append(smart_reranker(pred, label))

        results_by_strategy[strategy] = {
            'count': len(strategy_items),
            'EM': np.mean(em_scores),
            'F1': np.mean(f1_scores),
            'Reranker': np.mean(reranker_scores)
        }

    # 计算总体分数
    all_em_scores = []
    all_f1_scores = []
    all_reranker_scores = []
    all_original_f1_scores = []

    for item in cleaned_data:
        pred = item['predict']
        orig_pred = item['original_predict']
        label = item['label']

        all_em_scores.append(exact_match(pred, label))
        all_f1_scores.append(f1_score(pred, label))
        all_reranker_scores.append(smart_reranker(pred, label))
        all_original_f1_scores.append(f1_score(orig_pred, label))

    overall_results = {
        'EM': np.mean(all_em_scores),
        'F1': np.mean(all_f1_scores),
        'Reranker': np.mean(all_reranker_scores)
    }

    original_f1 = np.mean(all_original_f1_scores)

    # 打印结果
    print(f"\n=== 智能清理结果 ===")
    print(f"总体改进:")
    print(f"  原始F1: {original_f1:.4f}")
    print(f"  清理后F1: {overall_results['F1']:.4f} (改进: {overall_results['F1'] - original_f1:+.4f})")
    print(f"  EM: {overall_results['EM']:.4f}")
    print(f"  智能Reranker: {overall_results['Reranker']:.4f}")

    print(f"\n按策略分类的结果:")
    for strategy, results in results_by_strategy.items():
        print(f"  {strategy} ({results['count']} 样本):")
        print(f"    EM: {results['EM']:.4f}, F1: {results['F1']:.4f}, Reranker: {results['Reranker']:.4f}")

    # 保存结果
    if output_file:
        print(f"\n正在保存清理后的数据到: {output_file}")
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in cleaned_data:
                # 保存清理后的数据，包含策略信息
                save_item = {
                    'prompt': item['prompt'],
                    'predict': item['predict'],
                    'label': item['label'],
                    'clean_strategy': item['clean_strategy']
                }
                f.write(json.dumps(save_item, ensure_ascii=False) + '\n')

    return overall_results, results_by_strategy

# 运行智能清理
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/high_quality_samples.jsonl"
    output_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/smart_cleaned_predictions.jsonl"

    overall_results, strategy_results = advanced_evaluation(input_file, output_file)

正在加载数据...
加载了 274 条数据
正在分析和清理预测...

清理策略统计:
  conservative: 273 个样本 (99.6%)
  repetition_fix: 1 个样本 (0.4%)

正在重新计算评估指标...

=== 智能清理结果 ===
总体改进:
  原始F1: 0.3159
  清理后F1: 0.3159 (改进: +0.0000)
  EM: 0.0000
  智能Reranker: 0.3613

按策略分类的结果:
  conservative (273 样本):
    EM: 0.0000, F1: 0.3163, Reranker: 0.3590
  repetition_fix (1 样本):
    EM: 0.0000, F1: 0.1993, Reranker: 1.0000

正在保存清理后的数据到: /content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/smart_cleaned_predictions.jsonl


In [19]:
import json
import os
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any, Tuple

class EvaluationMetrics:
    """计算各种NLP评估指标的类"""

    @staticmethod
    def normalize_answer(s: str) -> str:
        """标准化答案文本"""
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    @staticmethod
    def exact_match(prediction: str, ground_truth: str) -> float:
        """计算精确匹配(EM)分数"""
        return float(EvaluationMetrics.normalize_answer(prediction) ==
                    EvaluationMetrics.normalize_answer(ground_truth))

    @staticmethod
    def f1_score(prediction: str, ground_truth: str) -> float:
        """计算F1分数"""
        pred_tokens = EvaluationMetrics.normalize_answer(prediction).split()
        truth_tokens = EvaluationMetrics.normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        if precision + recall == 0:
            return 0.0

        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    @staticmethod
    def anls_score(prediction: str, ground_truth: str, threshold: float = 0.5) -> float:
        """计算ANLS分数"""
        def levenshtein_distance(s1: str, s2: str) -> int:
            if len(s1) < len(s2):
                return levenshtein_distance(s2, s1)

            if len(s2) == 0:
                return len(s1)

            previous_row = list(range(len(s2) + 1))
            for i, c1 in enumerate(s1):
                current_row = [i + 1]
                for j, c2 in enumerate(s2):
                    insertions = previous_row[j + 1] + 1
                    deletions = current_row[j] + 1
                    substitutions = previous_row[j] + (c1 != c2)
                    current_row.append(min(insertions, deletions, substitutions))
                previous_row = current_row

            return previous_row[-1]

        pred_norm = EvaluationMetrics.normalize_answer(prediction)
        truth_norm = EvaluationMetrics.normalize_answer(ground_truth)

        if len(truth_norm) == 0:
            return 1.0 if len(pred_norm) == 0 else 0.0

        edit_distance = levenshtein_distance(pred_norm, truth_norm)
        max_len = max(len(pred_norm), len(truth_norm))

        if max_len == 0:
            return 1.0

        normalized_distance = edit_distance / max_len
        similarity = 1 - normalized_distance

        return similarity if similarity >= threshold else 0.0

    @staticmethod
    def reranker_score(predictions: List[str], ground_truths: List[str], k: int = 1) -> float:
        """计算Reranker分数"""
        if not predictions or not ground_truths:
            return 0.0

        top_k_preds = predictions[:k] if len(predictions) >= k else predictions

        for pred in top_k_preds:
            for truth in ground_truths:
                if EvaluationMetrics.exact_match(pred, truth):
                    return 1.0

        return 0.0

def load_data(input_file: str) -> List[Dict[str, Any]]:
    """加载数据文件"""
    data = []

    if not os.path.exists(input_file):
        print(f"文件不存在: {input_file}")
        return data

    if input_file.endswith('.json'):
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    elif input_file.endswith('.jsonl'):
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line.strip()))
    else:
        if os.path.isdir(input_file):
            for filename in os.listdir(input_file):
                if filename.endswith(('.json', '.jsonl')):
                    filepath = os.path.join(input_file, filename)
                    if filename.endswith('.json'):
                        with open(filepath, 'r', encoding='utf-8') as f:
                            file_data = json.load(f)
                            if isinstance(file_data, list):
                                data.extend(file_data)
                            else:
                                data.append(file_data)
                    else:
                        with open(filepath, 'r', encoding='utf-8') as f:
                            for line in f:
                                if line.strip():
                                    data.append(json.loads(line.strip()))

    return data

def evaluate_model(input_file: str, output_file: str = None) -> Dict[str, float]:
    """评估模型性能"""
    print(f"正在加载数据: {input_file}")
    data = load_data(input_file)

    if not data:
        print("未能加载任何数据")
        return {}

    print(f"加载了 {len(data)} 条数据")

    # 检查数据格式
    print("\n=== 数据格式检查 ===")
    sample_item = data[0]
    if isinstance(sample_item, dict):
        print("第一个样本的字段:")
        for key in sample_item.keys():
            value = sample_item[key]
            if isinstance(value, str):
                preview = value[:50] + "..." if len(value) > 50 else value
                print(f"  {key}: '{preview}'")

    # 提取预测结果和真实标签
    predictions = []
    ground_truths = []

    for item in data:
        if isinstance(item, dict):
            # 根据您的数据格式，使用 predict 和 label 字段
            if 'predict' in item and 'label' in item:
                predictions.append(str(item['predict']))
                ground_truths.append(str(item['label']))

    if not predictions:
        print("未能提取任何预测结果")
        return {}

    print(f"成功提取 {len(predictions)} 条样本")
    print(f"预测样例: {predictions[0][:100]}...")
    print(f"真实标签样例: {ground_truths[0][:100]}...")

    print(f"\n正在计算评估指标...")

    # 计算各种指标
    em_scores = []
    f1_scores = []
    anls_scores = []
    reranker_scores = []

    for pred, truth in zip(predictions, ground_truths):
        em_scores.append(EvaluationMetrics.exact_match(pred, truth))
        f1_scores.append(EvaluationMetrics.f1_score(pred, truth))
        anls_scores.append(EvaluationMetrics.anls_score(pred, truth))
        reranker_scores.append(EvaluationMetrics.reranker_score([pred], [truth]))

    # 计算平均分数
    results = {
        'EM': np.mean(em_scores),
        'F1': np.mean(f1_scores),
        'ANLS': np.mean(anls_scores),
        'Reranker': np.mean(reranker_scores),
        'sample_count': len(predictions)
    }

    # 打印结果
    print("\n=== 评估结果 ===")
    print(f"样本数量: {results['sample_count']}")
    print(f"EM Score: {results['EM']:.4f} ({results['EM']*100:.2f}%)")
    print(f"F1 Score: {results['F1']:.4f} ({results['F1']*100:.2f}%)")
    print(f"ANLS Score: {results['ANLS']:.4f} ({results['ANLS']*100:.2f}%)")
    print(f"Reranker Score: {results['Reranker']:.4f} ({results['Reranker']*100:.2f}%)")

    # 保存结果
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"\n结果已保存到: {output_file}")

    return results

# 直接运行评估
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/high_quality_samples.jsonl"
    results = evaluate_model(input_file, "evaluation_results.json")

正在加载数据: /content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/high_quality_samples.jsonl
加载了 274 条数据

=== 数据格式检查 ===
第一个样本的字段:
  prompt: '<|im_start|>system
You are a helpful assistant.<|i...'
  predict: 'Tony the knowledge and authority rapport with the ...'
  label: 'Tony Robbins describes six core human needs that d...'
成功提取 274 条样本
预测样例: Tony the knowledge and authority rapport with the audience. 5. Benefits:: presenter offers an a offe...
真实标签样例: Tony Robbins describes six core human needs that drive our behaviors and motivations. These six need...

正在计算评估指标...

=== 评估结果 ===
样本数量: 274
EM Score: 0.0000 (0.00%)
F1 Score: 0.3159 (31.59%)
ANLS Score: 0.0110 (1.10%)
Reranker Score: 0.0000 (0.00%)

结果已保存到: evaluation_results.json


base

In [20]:
import json
import os
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any, Tuple

class EvaluationMetrics:
    """计算各种NLP评估指标的类"""

    @staticmethod
    def normalize_answer(s: str) -> str:
        """标准化答案文本"""
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    @staticmethod
    def exact_match(prediction: str, ground_truth: str) -> float:
        """计算精确匹配(EM)分数"""
        return float(EvaluationMetrics.normalize_answer(prediction) ==
                    EvaluationMetrics.normalize_answer(ground_truth))

    @staticmethod
    def f1_score(prediction: str, ground_truth: str) -> float:
        """计算F1分数"""
        pred_tokens = EvaluationMetrics.normalize_answer(prediction).split()
        truth_tokens = EvaluationMetrics.normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        if precision + recall == 0:
            return 0.0

        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    @staticmethod
    def anls_score(prediction: str, ground_truth: str, threshold: float = 0.5) -> float:
        """计算ANLS分数"""
        def levenshtein_distance(s1: str, s2: str) -> int:
            if len(s1) < len(s2):
                return levenshtein_distance(s2, s1)

            if len(s2) == 0:
                return len(s1)

            previous_row = list(range(len(s2) + 1))
            for i, c1 in enumerate(s1):
                current_row = [i + 1]
                for j, c2 in enumerate(s2):
                    insertions = previous_row[j + 1] + 1
                    deletions = current_row[j] + 1
                    substitutions = previous_row[j] + (c1 != c2)
                    current_row.append(min(insertions, deletions, substitutions))
                previous_row = current_row

            return previous_row[-1]

        pred_norm = EvaluationMetrics.normalize_answer(prediction)
        truth_norm = EvaluationMetrics.normalize_answer(ground_truth)

        if len(truth_norm) == 0:
            return 1.0 if len(pred_norm) == 0 else 0.0

        edit_distance = levenshtein_distance(pred_norm, truth_norm)
        max_len = max(len(pred_norm), len(truth_norm))

        if max_len == 0:
            return 1.0

        normalized_distance = edit_distance / max_len
        similarity = 1 - normalized_distance

        return similarity if similarity >= threshold else 0.0

    @staticmethod
    def reranker_score(predictions: List[str], ground_truths: List[str], k: int = 1) -> float:
        """计算Reranker分数"""
        if not predictions or not ground_truths:
            return 0.0

        top_k_preds = predictions[:k] if len(predictions) >= k else predictions

        for pred in top_k_preds:
            for truth in ground_truths:
                if EvaluationMetrics.exact_match(pred, truth):
                    return 1.0

        return 0.0

def load_data(input_file: str) -> List[Dict[str, Any]]:
    """加载数据文件"""
    data = []

    if not os.path.exists(input_file):
        print(f"文件不存在: {input_file}")
        return data

    if input_file.endswith('.json'):
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    elif input_file.endswith('.jsonl'):
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line.strip()))
    else:
        if os.path.isdir(input_file):
            for filename in os.listdir(input_file):
                if filename.endswith(('.json', '.jsonl')):
                    filepath = os.path.join(input_file, filename)
                    if filename.endswith('.json'):
                        with open(filepath, 'r', encoding='utf-8') as f:
                            file_data = json.load(f)
                            if isinstance(file_data, list):
                                data.extend(file_data)
                            else:
                                data.append(file_data)
                    else:
                        with open(filepath, 'r', encoding='utf-8') as f:
                            for line in f:
                                if line.strip():
                                    data.append(json.loads(line.strip()))

    return data

def evaluate_model(input_file: str, output_file: str = None) -> Dict[str, float]:
    """评估模型性能"""
    print(f"正在加载数据: {input_file}")
    data = load_data(input_file)

    if not data:
        print("未能加载任何数据")
        return {}

    print(f"加载了 {len(data)} 条数据")

    # 检查数据格式
    print("\n=== 数据格式检查 ===")
    sample_item = data[0]
    if isinstance(sample_item, dict):
        print("第一个样本的字段:")
        for key in sample_item.keys():
            value = sample_item[key]
            if isinstance(value, str):
                preview = value[:50] + "..." if len(value) > 50 else value
                print(f"  {key}: '{preview}'")

    # 提取预测结果和真实标签
    predictions = []
    ground_truths = []

    for item in data:
        if isinstance(item, dict):
            # 根据您的数据格式，使用 predict 和 label 字段
            if 'predict' in item and 'label' in item:
                predictions.append(str(item['predict']))
                ground_truths.append(str(item['label']))

    if not predictions:
        print("未能提取任何预测结果")
        return {}

    print(f"成功提取 {len(predictions)} 条样本")
    print(f"预测样例: {predictions[0][:100]}...")
    print(f"真实标签样例: {ground_truths[0][:100]}...")

    print(f"\n正在计算评估指标...")

    # 计算各种指标
    em_scores = []
    f1_scores = []
    anls_scores = []
    reranker_scores = []

    for pred, truth in zip(predictions, ground_truths):
        em_scores.append(EvaluationMetrics.exact_match(pred, truth))
        f1_scores.append(EvaluationMetrics.f1_score(pred, truth))
        anls_scores.append(EvaluationMetrics.anls_score(pred, truth))
        reranker_scores.append(EvaluationMetrics.reranker_score([pred], [truth]))

    # 计算平均分数
    results = {
        'EM': np.mean(em_scores),
        'F1': np.mean(f1_scores),
        'ANLS': np.mean(anls_scores),
        'Reranker': np.mean(reranker_scores),
        'sample_count': len(predictions)
    }

    # 打印结果
    print("\n=== 评估结果 ===")
    print(f"样本数量: {results['sample_count']}")
    print(f"EM Score: {results['EM']:.4f} ({results['EM']*100:.2f}%)")
    print(f"F1 Score: {results['F1']:.4f} ({results['F1']*100:.2f}%)")
    print(f"ANLS Score: {results['ANLS']:.4f} ({results['ANLS']*100:.2f}%)")
    print(f"Reranker Score: {results['Reranker']:.4f} ({results['Reranker']*100:.2f}%)")

    # 保存结果
    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"\n结果已保存到: {output_file}")

    return results

# 直接运行评估
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/basevbshare/generated_predictions.jsonl"
    results = evaluate_model(input_file, "evaluation_results.json")

正在加载数据: /content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/basevbshare/generated_predictions.jsonl
加载了 548 条数据

=== 数据格式检查 ===
第一个样本的字段:
  prompt: '<|im_start|>system
You are a helpful assistant.<|i...'
  predict: 'Tony the expertise and providing credibility.
 the...'
  label: 'Tony Robbins describes six core human needs that d...'
成功提取 548 条样本
预测样例: Tony the expertise and providing credibility.
 the audience.
5. Call: The presenter offers the offer...
真实标签样例: Tony Robbins describes six core human needs that drive our behaviors and motivations. These six need...

正在计算评估指标...

=== 评估结果 ===
样本数量: 548
EM Score: 0.0018 (0.18%)
F1 Score: 0.2021 (20.21%)
ANLS Score: 0.0048 (0.48%)
Reranker Score: 0.0018 (0.18%)

结果已保存到: evaluation_results.json


In [1]:
import json
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any

class PredictionCleaner:
    """用于清理和改进预测结果的类"""

    @staticmethod
    def remove_repetitions(text: str, max_repeat: int = 3) -> str:
        """移除过度重复的词汇"""
        words = text.split()
        if not words:
            return text

        cleaned_words = []
        word_count = Counter()

        for word in words:
            word_lower = word.lower().strip('.,!?;:')
            word_count[word_lower] += 1

            # 如果这个词已经重复太多次，跳过
            if word_count[word_lower] <= max_repeat:
                cleaned_words.append(word)
            elif len(cleaned_words) == 0:  # 保留至少一个词
                cleaned_words.append(word)

        return ' '.join(cleaned_words)

    @staticmethod
    def remove_format_tokens(text: str) -> str:
        """移除格式标记和错误的系统词汇"""
        # 移除chat格式标记
        text = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', text, flags=re.DOTALL)
        text = re.sub(r'<\|.*?\|>', '', text)

        # 移除assistant相关词汇
        text = re.sub(r'\b(assistant|system|user)\b', '', text, flags=re.IGNORECASE)

        # 移除开头的重复名字（如 "Tony Tonyassistant"）
        text = re.sub(r'^(\w+)\s*\1\s*(assistant|system)?\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 移除其他常见错误格式
        text = re.sub(r'\bhelpful assistant\b', '', text, flags=re.IGNORECASE)

        return text.strip()

    @staticmethod
    def remove_sentence_repetitions(text: str) -> str:
        """移除重复的句子"""
        sentences = re.split(r'[.!?]+', text)
        seen_sentences = set()
        unique_sentences = []

        for sentence in sentences:
            sentence_clean = sentence.strip().lower()
            if sentence_clean and sentence_clean not in seen_sentences:
                seen_sentences.add(sentence_clean)
                unique_sentences.append(sentence.strip())

        return '. '.join(s for s in unique_sentences if s) + ('.' if unique_sentences else '')

    @staticmethod
    def fix_common_errors(text: str) -> str:
        """修复常见错误"""
        # 修复粘连的词（如 "Tonyassistant" -> "Tony"）
        text = re.sub(r'(\w+)assistant', r'\1', text, flags=re.IGNORECASE)
        text = re.sub(r'(\w+)system', r'\1', text, flags=re.IGNORECASE)

        # 修复多余的换行和空格
        text = re.sub(r'\n+', ' ', text)
        text = re.sub(r'\s+', ' ', text)

        # 修复开头的错误（移除开头重复的词）
        words = text.split()
        if len(words) > 1 and words[0].lower() == words[1].lower():
            text = ' '.join(words[1:])

        return text.strip()

    @classmethod
    def clean_prediction(cls, text: str) -> str:
        """综合清理预测文本"""
        # 按步骤清理
        text = cls.remove_format_tokens(text)
        text = cls.fix_common_errors(text)
        text = cls.remove_repetitions(text, max_repeat=2)  # 更严格的重复控制
        text = cls.remove_sentence_repetitions(text)

        return text.strip()

def clean_and_evaluate(input_file: str, output_file: str = None):
    """清理预测并重新评估"""

    # 加载数据
    print("正在加载数据...")
    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line.strip()))

    print(f"加载了 {len(data)} 条数据")

    # 清理预测
    print("正在清理预测结果...")
    cleaned_data = []
    improvement_count = 0

    for i, item in enumerate(data):
        original_pred = item['predict']
        cleaned_pred = PredictionCleaner.clean_prediction(original_pred)

        # 检查是否有改进
        if cleaned_pred != original_pred:
            improvement_count += 1

        cleaned_item = item.copy()
        cleaned_item['original_predict'] = original_pred
        cleaned_item['predict'] = cleaned_pred
        cleaned_data.append(cleaned_item)

        # 显示前几个清理示例
        if i < 5:
            print(f"\n--- 样本 {i+1} 清理对比 ---")
            print(f"原始: {original_pred[:150]}...")
            print(f"清理后: {cleaned_pred[:150]}...")

    print(f"\n改进了 {improvement_count} 个样本 ({improvement_count/len(data)*100:.1f}%)")

    # 重新计算评估指标
    print("\n正在重新计算评估指标...")

    # 评估函数
    def normalize_answer(s: str) -> str:
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)
        def white_space_fix(text):
            return ' '.join(text.split())
        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)
        def lower(text):
            return text.lower()
        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def exact_match(prediction: str, ground_truth: str) -> float:
        return float(normalize_answer(prediction) == normalize_answer(ground_truth))

    def f1_score(prediction: str, ground_truth: str) -> float:
        pred_tokens = normalize_answer(prediction).split()
        truth_tokens = normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        if precision + recall == 0:
            return 0.0

        return (2 * precision * recall) / (precision + recall)

    def partial_match_score(prediction: str, ground_truth: str, threshold: float = 0.7) -> float:
        """计算部分匹配分数，用于改进的Reranker"""
        f1 = f1_score(prediction, ground_truth)
        return 1.0 if f1 >= threshold else 0.0

    # 计算原始分数和清理后分数
    original_em_scores = []
    original_f1_scores = []
    original_reranker_scores = []

    cleaned_em_scores = []
    cleaned_f1_scores = []
    cleaned_reranker_scores = []

    for item in cleaned_data:
        original_pred = item['original_predict']
        cleaned_pred = item['predict']
        label = item['label']

        # 原始分数
        original_em_scores.append(exact_match(original_pred, label))
        original_f1_scores.append(f1_score(original_pred, label))
        original_reranker_scores.append(partial_match_score(original_pred, label, 0.5))

        # 清理后分数
        cleaned_em_scores.append(exact_match(cleaned_pred, label))
        cleaned_f1_scores.append(f1_score(cleaned_pred, label))
        cleaned_reranker_scores.append(partial_match_score(cleaned_pred, label, 0.5))

    # 计算平均分数
    original_results = {
        'EM': np.mean(original_em_scores),
        'F1': np.mean(original_f1_scores),
        'Reranker': np.mean(original_reranker_scores)
    }

    cleaned_results = {
        'EM': np.mean(cleaned_em_scores),
        'F1': np.mean(cleaned_f1_scores),
        'Reranker': np.mean(cleaned_reranker_scores)
    }

    # 打印对比结果
    print("\n=== 清理前后对比 ===")
    print(f"样本数量: {len(data)}")
    print(f"\n{'指标':<15} {'原始':<10} {'清理后':<10} {'改进':<10}")
    print("-" * 50)

    for metric in ['EM', 'F1', 'Reranker']:
        original = original_results[metric]
        cleaned = cleaned_results[metric]
        improvement = cleaned - original
        print(f"{metric:<15} {original:<10.4f} {cleaned:<10.4f} {improvement:>+.4f}")

    # 找出改进最明显的样本
    print("\n=== 改进最明显的样本 ===")
    improvements = []
    for i, item in enumerate(cleaned_data):
        original_f1 = f1_score(item['original_predict'], item['label'])
        cleaned_f1 = f1_score(item['predict'], item['label'])
        improvement = cleaned_f1 - original_f1
        if improvement > 0.1:  # 改进超过0.1的样本
            improvements.append((i, improvement, item))

    improvements.sort(key=lambda x: x[1], reverse=True)
    for i, (idx, improvement, item) in enumerate(improvements[:3]):
        print(f"\n样本 {idx} (改进 +{improvement:.3f}):")
        print(f"原始: {item['original_predict'][:100]}...")
        print(f"清理: {item['predict'][:100]}...")
        print(f"标签: {item['label'][:100]}...")

    # 保存清理后的数据
    if output_file:
        print(f"\n正在保存清理后的数据到: {output_file}")
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in cleaned_data:
                # 只保存清理后的predict，移除original_predict以节省空间
                save_item = {k: v for k, v in item.items() if k != 'original_predict'}
                f.write(json.dumps(save_item, ensure_ascii=False) + '\n')

    return original_results, cleaned_results

# 运行清理和评估
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/basevbshare/generated_predictions.jsonl"
    output_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/cleaned_predictions.jsonl"

    original_results, cleaned_results = clean_and_evaluate(input_file, output_file)

正在加载数据...
加载了 548 条数据
正在清理预测结果...

--- 样本 1 清理对比 ---
原始: Tony the expertise and providing credibility.
 the audience.
5. Call: The presenter offers the offer or service they they are are, and why why in solv...
清理后: Tony the expertise and providing credibility. the audience. 5. Call: presenter offers offer or service they they are are, and why why in solves a prob...

--- 样本 2 清理对比 ---
原始: Sure, I can tell you how to tell if a customer segment is well segmented. Here are three bullet points to help you out:

1. Does the customer segment ...
清理后: Sure, I can tell you how to tell if a customer segment is well segmented. Here are three bullet points to help you out: 1. Does the customer segment h...

--- 样本 3 清理对比 ---
原始: To replace the string "This is a new {object} at {place}" with a Map, you can use the `Map` class in Java. Here's how you can do it:

1. First, you ne...
清理后: To replace the string "This is a new {object} at {place}" with a Map, you can use the `Map` class in Java. Here'

vb

llm base


In [3]:
!pip install -U flagembedding


Collecting flagembedding
  Downloading FlagEmbedding-1.3.5.tar.gz (163 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/163.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m163.8/163.9 kB[0m [31m4.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.9/163.9 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ir-datasets (from flagembedding)
  Downloading ir_datasets-0.5.11-py3-none-any.whl.metadata (12 kB)
Collecting inscriptis>=2.2.0 (from ir-datasets->flagembedding)
  Downloading inscriptis-2.6.0-py3-none-any.whl.metadata (25 kB)
Collecting trec-car-tools>=2.5.4 (from ir-datasets->flagembedding)
  Downloading trec_car_tools-2.6-py3-none-any.whl.metadata (640 bytes)
Collecting lz4>=3.1.10 (from ir-datasets->flagembedding)
  Downloading lz4-4.4.4-cp312-cp312-manylinux_2_17_x86_64.manylinu

In [None]:
import json
import time
import os
from openai import OpenAI

# 设置API Key
os.environ["OPENAI_API_KEY"] = ""
client = OpenAI()

def read_jsonl_data(path):
    """读取JSONL数据"""
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                data.append({
                    'prompt': item.get('prompt', ''),
                    'prediction': item.get('predict', ''),
                    'reference': item.get('label', '')
                })
    return data

def gpt_judge(prediction, reference, prompt=""):
    """直接用GPT判断"""

    # 识别任务类型
    task_type = "general conversation"
    if prompt:
        prompt_lower = prompt.lower()
        if any(word in prompt_lower for word in ['code', 'programming', 'function', 'algorithm']):
            task_type = "coding"
        elif any(word in prompt_lower for word in ['explain', 'what is', 'describe', 'definition']):
            task_type = "explanation"
        elif any(word in prompt_lower for word in ['list', 'steps', 'how to', 'instruction']):
            task_type = "instruction"
        elif any(word in prompt_lower for word in ['analyze', 'compare', 'evaluate']):
            task_type = "analysis"

    judge_prompt = f"""You are evaluating an AI assistant's response for a {task_type} task.

Evaluation Guidelines:
- Focus on factual correctness and helpfulness
- Accept different but semantically equivalent answers
- Ignore formatting, style, and minor wording differences
- Consider the core message and information accuracy
- Be reasonably lenient - if the prediction conveys the same key information as the reference, it should be considered correct
- For explanations: focus on conceptual accuracy rather than exact phrasing
- For instructions: check if the steps achieve the same goal
- For analysis: evaluate logical reasoning and key conclusions

Important: Many responses may be worded differently but still be factually correct and helpful.

Output format: {{"score": 1, "reason": "explanation"}} for correct responses
               {{"score": 0, "reason": "explanation"}} for incorrect responses

Score 1: The prediction is factually correct and helpful (even if expressed differently)
Score 0: The prediction is factually wrong, misleading, or significantly unhelpful"""

    # 限制文本长度以节省token
    max_length = 1000
    pred_text = prediction[:max_length]
    ref_text = reference[:max_length]

    # 添加省略号如果被截断
    if len(prediction) > max_length:
        pred_text += " [truncated...]"
    if len(reference) > max_length:
        ref_text += " [truncated...]"

    user_message = f"""Task Type: {task_type}

AI Prediction:
{pred_text}

Reference Answer:
{ref_text}

Is the AI prediction factually correct and helpful compared to the reference? Output JSON only."""

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            temperature=0.1,
            messages=[
                {"role": "system", "content": judge_prompt},
                {"role": "user", "content": user_message}
            ],
            max_tokens=150
        )

        content = response.choices[0].message.content.strip()

        # 解析JSON
        try:
            # 查找JSON部分
            start = content.find("{")
            end = content.rfind("}") + 1

            if start != -1 and end > start:
                json_text = content[start:end]
                result = json.loads(json_text)

                score = int(result.get("score", 0))
                reason = result.get("reason", "no reason provided")

                return score, reason
            else:
                # 如果没有找到JSON，尝试从文本中提取
                if "score" in content.lower():
                    if "1" in content and ("correct" in content.lower() or "accurate" in content.lower()):
                        return 1, "extracted_positive"
                    else:
                        return 0, "extracted_negative"
                return 0, f"parse_failed: {content}"

        except json.JSONDecodeError:
            # JSON解析失败，基于关键词判断
            content_lower = content.lower()
            if any(word in content_lower for word in ["correct", "accurate", "helpful", "good"]):
                return 1, f"keyword_positive: {content[:50]}"
            else:
                return 0, f"keyword_negative: {content[:50]}"

    except Exception as e:
        return 0, f"api_error: {str(e)}"

def evaluate_direct_gpt(input_path, sample_size=None, save_results=True):
    """直接用GPT评估所有样本"""

    print("=== 直接GPT判断评估 ===")
    print(f"输入文件: {input_path}")

    # 读取数据
    data = read_jsonl_data(input_path)
    print(f"加载了 {len(data)} 个样本")

    # 采样
    if sample_size and sample_size < len(data):
        import random
        data = random.sample(data, sample_size)
        print(f"采样 {sample_size} 个样本进行评估")

    results = []
    correct_count = 0
    error_count = 0

    print(f"\n开始处理 {len(data)} 个样本...")

    for i, item in enumerate(data):
        print(f"处理样本 {i+1}/{len(data)}", end="")

        pred = item['prediction']
        ref = item['reference']
        prompt = item['prompt']

        # 直接调用GPT判断
        score, reason = gpt_judge(pred, ref, prompt)

        if score == 1:
            correct_count += 1
            print(" ✓")
        else:
            print(f" ✗ ({reason[:30]}...)")

        if "api_error" in reason:
            error_count += 1

        results.append({
            'index': i,
            'prompt': prompt[:100] + "..." if len(prompt) > 100 else prompt,
            'prediction': pred[:200] + "..." if len(pred) > 200 else pred,
            'reference': ref[:200] + "..." if len(ref) > 200 else ref,
            'gpt_score': score,
            'gpt_reason': reason
        })

        # API限速
        time.sleep(0.5)

        # 每10个样本显示进度
        if (i + 1) % 10 == 0:
            current_accuracy = correct_count / (i + 1)
            print(f"  进度: {i+1}/{len(data)}, 当前准确率: {current_accuracy:.3f}")

    # 计算最终结果
    total_samples = len(results)
    accuracy = correct_count / total_samples

    print(f"\n=== 评估结果 ===")
    print(f"总样本数: {total_samples}")
    print(f"正确样本: {correct_count}")
    print(f"错误样本: {total_samples - correct_count}")
    print(f"API错误: {error_count}")
    print(f"最终准确率: {accuracy:.4f} ({accuracy*100:.2f}%)")

    # 显示一些示例
    print(f"\n=== 正确样本示例 ===")
    correct_examples = [r for r in results if r['gpt_score'] == 1][:2]
    for ex in correct_examples:
        print(f"预测: {ex['prediction']}")
        print(f"参考: {ex['reference']}")
        print(f"原因: {ex['gpt_reason']}")
        print()

    print(f"=== 错误样本示例 ===")
    incorrect_examples = [r for r in results if r['gpt_score'] == 0][:2]
    for ex in incorrect_examples:
        print(f"预测: {ex['prediction']}")
        print(f"参考: {ex['reference']}")
        print(f"原因: {ex['gpt_reason']}")
        print()

    # 保存结果
    if save_results:
        output_file = f"gpt_direct_evaluation_results.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'summary': {
                    'total_samples': total_samples,
                    'correct_count': correct_count,
                    'accuracy': accuracy,
                    'api_errors': error_count
                },
                'detailed_results': results
            }, f, indent=2, ensure_ascii=False)

        print(f"\n结果已保存到: {output_file}")

    return results

def quick_test(input_path, n_samples=3):
    """快速测试GPT调用"""
    print("=== 快速GPT测试 ===")

    data = read_jsonl_data(input_path)
    if not data:
        print("无法读取数据")
        return

    # 测试前几个样本
    for i in range(min(n_samples, len(data))):
        item = data[i]
        print(f"\n--- 测试样本 {i+1} ---")
        print(f"预测: {item['prediction'][:100]}...")
        print(f"参考: {item['reference'][:100]}...")

        score, reason = gpt_judge(item['prediction'], item['reference'], item['prompt'])

        print(f"GPT评分: {score}")
        print(f"原因: {reason}")

        time.sleep(1)  # 短暂暂停

# 使用示例
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/basevbshare/generated_predictions.jsonl"

    results = evaluate_direct_gpt(input_file, sample_size=None, save_results=True)

=== 直接GPT判断评估 ===
输入文件: /content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/basevbshare/generated_predictions.jsonl
加载了 548 个样本

开始处理 548 个样本...
处理样本 1/548 ✗ (The AI prediction does not acc...)
处理样本 2/548 ✗ (The AI prediction does not pro...)
处理样本 3/548 ✗ (The AI prediction incorrectly ...)
处理样本 4/548 ✓
处理样本 5/548 ✗ (The AI prediction does not add...)
处理样本 6/548 ✗ (The AI prediction does not pro...)
处理样本 7/548 ✗ (The AI prediction provides det...)
处理样本 8/548 ✗ (The AI prediction does not pro...)
处理样本 9/548 ✗ (The AI prediction focuses sole...)
处理样本 10/548 ✗ (The AI prediction is incoheren...)
  进度: 10/548, 当前准确率: 0.100
处理样本 11/548 ✗ (The AI prediction provides a J...)
处理样本 12/548 ✗ (The AI prediction provides an ...)
处理样本 13/548 ✗ (The AI prediction contains num...)
处理样本 14/548 ✗ (The AI prediction is unclear, ...)
处理样本 15/548 ✗ (The AI prediction 'Toologicalv...)
处理样本 16/548 ✗ (The AI prediction does not add...)
处理样本 17/548 ✗ (The AI prediction does not pro...)
处理样本 18/548 ✗ (The AI p

lora

In [None]:
import json
import time
import os
from openai import OpenAI

# 设置API Key
os.environ["OPENAI_API_KEY"] = ""
client = OpenAI()

def read_jsonl_data(path):
    """读取JSONL数据"""
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                data.append({
                    'prompt': item.get('prompt', ''),
                    'prediction': item.get('predict', ''),
                    'reference': item.get('label', '')
                })
    return data

def gpt_judge(prediction, reference, prompt=""):
    """直接用GPT判断"""

    # 识别任务类型
    task_type = "general conversation"
    if prompt:
        prompt_lower = prompt.lower()
        if any(word in prompt_lower for word in ['code', 'programming', 'function', 'algorithm']):
            task_type = "coding"
        elif any(word in prompt_lower for word in ['explain', 'what is', 'describe', 'definition']):
            task_type = "explanation"
        elif any(word in prompt_lower for word in ['list', 'steps', 'how to', 'instruction']):
            task_type = "instruction"
        elif any(word in prompt_lower for word in ['analyze', 'compare', 'evaluate']):
            task_type = "analysis"

    judge_prompt = f"""You are evaluating an AI assistant's response for a {task_type} task.

Evaluation Guidelines:
- Focus on factual correctness and helpfulness
- Accept different but semantically equivalent answers
- Ignore formatting, style, and minor wording differences
- Consider the core message and information accuracy
- Be reasonably lenient - if the prediction conveys the same key information as the reference, it should be considered correct
- For explanations: focus on conceptual accuracy rather than exact phrasing
- For instructions: check if the steps achieve the same goal
- For analysis: evaluate logical reasoning and key conclusions

Important: Many responses may be worded differently but still be factually correct and helpful.

Output format: {{"score": 1, "reason": "explanation"}} for correct responses
               {{"score": 0, "reason": "explanation"}} for incorrect responses

Score 1: The prediction is factually correct and helpful (even if expressed differently)
Score 0: The prediction is factually wrong, misleading, or significantly unhelpful"""

    # 限制文本长度以节省token
    max_length = 1000
    pred_text = prediction[:max_length]
    ref_text = reference[:max_length]

    # 添加省略号如果被截断
    if len(prediction) > max_length:
        pred_text += " [truncated...]"
    if len(reference) > max_length:
        ref_text += " [truncated...]"

    user_message = f"""Task Type: {task_type}

AI Prediction:
{pred_text}

Reference Answer:
{ref_text}

Is the AI prediction factually correct and helpful compared to the reference? Output JSON only."""

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            temperature=0.1,
            messages=[
                {"role": "system", "content": judge_prompt},
                {"role": "user", "content": user_message}
            ],
            max_tokens=150
        )

        content = response.choices[0].message.content.strip()

        # 解析JSON
        try:
            # 查找JSON部分
            start = content.find("{")
            end = content.rfind("}") + 1

            if start != -1 and end > start:
                json_text = content[start:end]
                result = json.loads(json_text)

                score = int(result.get("score", 0))
                reason = result.get("reason", "no reason provided")

                return score, reason
            else:
                # 如果没有找到JSON，尝试从文本中提取
                if "score" in content.lower():
                    if "1" in content and ("correct" in content.lower() or "accurate" in content.lower()):
                        return 1, "extracted_positive"
                    else:
                        return 0, "extracted_negative"
                return 0, f"parse_failed: {content}"

        except json.JSONDecodeError:
            # JSON解析失败，基于关键词判断
            content_lower = content.lower()
            if any(word in content_lower for word in ["correct", "accurate", "helpful", "good"]):
                return 1, f"keyword_positive: {content[:50]}"
            else:
                return 0, f"keyword_negative: {content[:50]}"

    except Exception as e:
        return 0, f"api_error: {str(e)}"

def evaluate_direct_gpt(input_path, sample_size=None, save_results=True):
    """直接用GPT评估所有样本"""

    print("=== 直接GPT判断评估 ===")
    print(f"输入文件: {input_path}")

    # 读取数据
    data = read_jsonl_data(input_path)
    print(f"加载了 {len(data)} 个样本")

    # 采样
    if sample_size and sample_size < len(data):
        import random
        data = random.sample(data, sample_size)
        print(f"采样 {sample_size} 个样本进行评估")

    results = []
    correct_count = 0
    error_count = 0

    print(f"\n开始处理 {len(data)} 个样本...")

    for i, item in enumerate(data):
        print(f"处理样本 {i+1}/{len(data)}", end="")

        pred = item['prediction']
        ref = item['reference']
        prompt = item['prompt']

        # 直接调用GPT判断
        score, reason = gpt_judge(pred, ref, prompt)

        if score == 1:
            correct_count += 1
            print(" ✓")
        else:
            print(f" ✗ ({reason[:30]}...)")

        if "api_error" in reason:
            error_count += 1

        results.append({
            'index': i,
            'prompt': prompt[:100] + "..." if len(prompt) > 100 else prompt,
            'prediction': pred[:200] + "..." if len(pred) > 200 else pred,
            'reference': ref[:200] + "..." if len(ref) > 200 else ref,
            'gpt_score': score,
            'gpt_reason': reason
        })

        # API限速
        time.sleep(0.5)

        # 每10个样本显示进度
        if (i + 1) % 10 == 0:
            current_accuracy = correct_count / (i + 1)
            print(f"  进度: {i+1}/{len(data)}, 当前准确率: {current_accuracy:.3f}")

    # 计算最终结果
    total_samples = len(results)
    accuracy = correct_count / total_samples

    print(f"\n=== 评估结果 ===")
    print(f"总样本数: {total_samples}")
    print(f"正确样本: {correct_count}")
    print(f"错误样本: {total_samples - correct_count}")
    print(f"API错误: {error_count}")
    print(f"最终准确率: {accuracy:.4f} ({accuracy*100:.2f}%)")

    # 显示一些示例
    print(f"\n=== 正确样本示例 ===")
    correct_examples = [r for r in results if r['gpt_score'] == 1][:2]
    for ex in correct_examples:
        print(f"预测: {ex['prediction']}")
        print(f"参考: {ex['reference']}")
        print(f"原因: {ex['gpt_reason']}")
        print()

    print(f"=== 错误样本示例 ===")
    incorrect_examples = [r for r in results if r['gpt_score'] == 0][:2]
    for ex in incorrect_examples:
        print(f"预测: {ex['prediction']}")
        print(f"参考: {ex['reference']}")
        print(f"原因: {ex['gpt_reason']}")
        print()

    # 保存结果
    if save_results:
        output_file = f"gpt_direct_evaluation_results.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'summary': {
                    'total_samples': total_samples,
                    'correct_count': correct_count,
                    'accuracy': accuracy,
                    'api_errors': error_count
                },
                'detailed_results': results
            }, f, indent=2, ensure_ascii=False)

        print(f"\n结果已保存到: {output_file}")

    return results

def quick_test(input_path, n_samples=3):
    """快速测试GPT调用"""
    print("=== 快速GPT测试 ===")

    data = read_jsonl_data(input_path)
    if not data:
        print("无法读取数据")
        return

    # 测试前几个样本
    for i in range(min(n_samples, len(data))):
        item = data[i]
        print(f"\n--- 测试样本 {i+1} ---")
        print(f"预测: {item['prediction'][:100]}...")
        print(f"参考: {item['reference'][:100]}...")

        score, reason = gpt_judge(item['prediction'], item['reference'], item['prompt'])

        print(f"GPT评分: {score}")
        print(f"原因: {reason}")

        time.sleep(1)  # 短暂暂停

# 使用示例
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl"

    results = evaluate_direct_gpt(input_file, sample_size=None, save_results=True)

=== 直接GPT判断评估 ===
输入文件: /content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl
加载了 548 个样本

开始处理 548 个样本...
处理样本 1/548 ✗ (The AI prediction incorrectly ...)
处理样本 2/548

KeyboardInterrupt: 

vb

In [26]:
import json
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any

class PredictionCleaner:
    """用于清理和改进预测结果的类"""

    @staticmethod
    def remove_repetitions(text: str, max_repeat: int = 3) -> str:
        """移除过度重复的词汇"""
        words = text.split()
        if not words:
            return text

        cleaned_words = []
        word_count = Counter()

        for word in words:
            word_lower = word.lower().strip('.,!?;:')
            word_count[word_lower] += 1

            # 如果这个词已经重复太多次，跳过
            if word_count[word_lower] <= max_repeat:
                cleaned_words.append(word)
            elif len(cleaned_words) == 0:  # 保留至少一个词
                cleaned_words.append(word)

        return ' '.join(cleaned_words)

    @staticmethod
    def remove_format_tokens(text: str) -> str:
        """移除格式标记和错误的系统词汇"""
        # 移除chat格式标记
        text = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', text, flags=re.DOTALL)
        text = re.sub(r'<\|.*?\|>', '', text)

        # 移除assistant相关词汇
        text = re.sub(r'\b(assistant|system|user)\b', '', text, flags=re.IGNORECASE)

        # 移除开头的重复名字（如 "Tony Tonyassistant"）
        text = re.sub(r'^(\w+)\s*\1\s*(assistant|system)?\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 移除其他常见错误格式
        text = re.sub(r'\bhelpful assistant\b', '', text, flags=re.IGNORECASE)

        return text.strip()

    @staticmethod
    def remove_sentence_repetitions(text: str) -> str:
        """移除重复的句子"""
        sentences = re.split(r'[.!?]+', text)
        seen_sentences = set()
        unique_sentences = []

        for sentence in sentences:
            sentence_clean = sentence.strip().lower()
            if sentence_clean and sentence_clean not in seen_sentences:
                seen_sentences.add(sentence_clean)
                unique_sentences.append(sentence.strip())

        return '. '.join(s for s in unique_sentences if s) + ('.' if unique_sentences else '')

    @staticmethod
    def fix_common_errors(text: str) -> str:
        """修复常见错误"""
        # 修复粘连的词（如 "Tonyassistant" -> "Tony"）
        text = re.sub(r'(\w+)assistant', r'\1', text, flags=re.IGNORECASE)
        text = re.sub(r'(\w+)system', r'\1', text, flags=re.IGNORECASE)

        # 修复多余的换行和空格
        text = re.sub(r'\n+', ' ', text)
        text = re.sub(r'\s+', ' ', text)

        # 修复开头的错误（移除开头重复的词）
        words = text.split()
        if len(words) > 1 and words[0].lower() == words[1].lower():
            text = ' '.join(words[1:])

        return text.strip()

    @classmethod
    def clean_prediction(cls, text: str) -> str:
        """综合清理预测文本"""
        # 按步骤清理
        text = cls.remove_format_tokens(text)
        text = cls.fix_common_errors(text)
        text = cls.remove_repetitions(text, max_repeat=2)  # 更严格的重复控制
        text = cls.remove_sentence_repetitions(text)

        return text.strip()

def clean_and_evaluate(input_file: str, output_file: str = None):
    """清理预测并重新评估"""

    # 加载数据
    print("正在加载数据...")
    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line.strip()))

    print(f"加载了 {len(data)} 条数据")

    # 清理预测
    print("正在清理预测结果...")
    cleaned_data = []
    improvement_count = 0

    for i, item in enumerate(data):
        original_pred = item['predict']
        cleaned_pred = PredictionCleaner.clean_prediction(original_pred)

        # 检查是否有改进
        if cleaned_pred != original_pred:
            improvement_count += 1

        cleaned_item = item.copy()
        cleaned_item['original_predict'] = original_pred
        cleaned_item['predict'] = cleaned_pred
        cleaned_data.append(cleaned_item)

        # 显示前几个清理示例
        if i < 5:
            print(f"\n--- 样本 {i+1} 清理对比 ---")
            print(f"原始: {original_pred[:150]}...")
            print(f"清理后: {cleaned_pred[:150]}...")

    print(f"\n改进了 {improvement_count} 个样本 ({improvement_count/len(data)*100:.1f}%)")

    # 重新计算评估指标
    print("\n正在重新计算评估指标...")

    # 评估函数
    def normalize_answer(s: str) -> str:
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)
        def white_space_fix(text):
            return ' '.join(text.split())
        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)
        def lower(text):
            return text.lower()
        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def exact_match(prediction: str, ground_truth: str) -> float:
        return float(normalize_answer(prediction) == normalize_answer(ground_truth))

    def f1_score(prediction: str, ground_truth: str) -> float:
        pred_tokens = normalize_answer(prediction).split()
        truth_tokens = normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        if precision + recall == 0:
            return 0.0

        return (2 * precision * recall) / (precision + recall)

    def partial_match_score(prediction: str, ground_truth: str, threshold: float = 0.7) -> float:
        """计算部分匹配分数，用于改进的Reranker"""
        f1 = f1_score(prediction, ground_truth)
        return 1.0 if f1 >= threshold else 0.0

    # 计算原始分数和清理后分数
    original_em_scores = []
    original_f1_scores = []
    original_reranker_scores = []

    cleaned_em_scores = []
    cleaned_f1_scores = []
    cleaned_reranker_scores = []

    for item in cleaned_data:
        original_pred = item['original_predict']
        cleaned_pred = item['predict']
        label = item['label']

        # 原始分数
        original_em_scores.append(exact_match(original_pred, label))
        original_f1_scores.append(f1_score(original_pred, label))
        original_reranker_scores.append(partial_match_score(original_pred, label, 0.5))

        # 清理后分数
        cleaned_em_scores.append(exact_match(cleaned_pred, label))
        cleaned_f1_scores.append(f1_score(cleaned_pred, label))
        cleaned_reranker_scores.append(partial_match_score(cleaned_pred, label, 0.5))

    # 计算平均分数
    original_results = {
        'EM': np.mean(original_em_scores),
        'F1': np.mean(original_f1_scores),
        'Reranker': np.mean(original_reranker_scores)
    }

    cleaned_results = {
        'EM': np.mean(cleaned_em_scores),
        'F1': np.mean(cleaned_f1_scores),
        'Reranker': np.mean(cleaned_reranker_scores)
    }

    # 打印对比结果
    print("\n=== 清理前后对比 ===")
    print(f"样本数量: {len(data)}")
    print(f"\n{'指标':<15} {'原始':<10} {'清理后':<10} {'改进':<10}")
    print("-" * 50)

    for metric in ['EM', 'F1', 'Reranker']:
        original = original_results[metric]
        cleaned = cleaned_results[metric]
        improvement = cleaned - original
        print(f"{metric:<15} {original:<10.4f} {cleaned:<10.4f} {improvement:>+.4f}")

    # 找出改进最明显的样本
    print("\n=== 改进最明显的样本 ===")
    improvements = []
    for i, item in enumerate(cleaned_data):
        original_f1 = f1_score(item['original_predict'], item['label'])
        cleaned_f1 = f1_score(item['predict'], item['label'])
        improvement = cleaned_f1 - original_f1
        if improvement > 0.1:  # 改进超过0.1的样本
            improvements.append((i, improvement, item))

    improvements.sort(key=lambda x: x[1], reverse=True)
    for i, (idx, improvement, item) in enumerate(improvements[:3]):
        print(f"\n样本 {idx} (改进 +{improvement:.3f}):")
        print(f"原始: {item['original_predict'][:100]}...")
        print(f"清理: {item['predict'][:100]}...")
        print(f"标签: {item['label'][:100]}...")

    # 保存清理后的数据
    if output_file:
        print(f"\n正在保存清理后的数据到: {output_file}")
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in cleaned_data:
                # 只保存清理后的predict，移除original_predict以节省空间
                save_item = {k: v for k, v in item.items() if k != 'original_predict'}
                f.write(json.dumps(save_item, ensure_ascii=False) + '\n')

    return original_results, cleaned_results

# 运行清理和评估
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/ultrachatvb/generated_predictions.jsonl"
    output_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/cleaned_predictionsvb.jsonl"

    original_results, cleaned_results = clean_and_evaluate(input_file, output_file)

正在加载数据...
加载了 548 条数据
正在清理预测结果...

--- 样本 1 清理对比 ---
原始: Tony the knowledge and authority rapport with the audience.
5. Benefits:: presenter offers an a offer or service being are are, emphasizing how it wor...
清理后: Tony the knowledge and authority rapport with the audience. 5. Benefits:: presenter offers an a offer or service being are are, emphasizing how it wor...

--- 样本 2 清理对比 ---
原始: Sure. You're right. We need to do a lot more work to make sure that we can provide good support for customers, and we need to do a lot more work to ma...
清理后: Sure. You're right. We need to do a lot more work to make sure that we can provide good support for customers, and need do a lot more work make that c...

--- 样本 3 清理对比 ---
原始: To make the string "This is a new {object} at {place}" into a Map, you can use the `Map` class in Java. Here's how you can do it:

1. First, you need ...
清理后: To make the string "This is a new {object} at {place}" into a Map, you can use the `Map` class in Java. Here's h

In [None]:
import json
import re
import string
from collections import Counter
import numpy as np
from typing import List, Dict, Any, Tuple

class SmartCleaner:
    """更智能的预测清理器"""

    @staticmethod
    def detect_severe_repetition(text: str) -> bool:
        """检测严重的重复问题"""
        words = text.split()
        if len(words) < 5:
            return False

        # 检查是否有词汇重复超过总词数的50%
        word_counts = Counter(words)
        max_count = max(word_counts.values()) if word_counts else 0
        repetition_ratio = max_count / len(words)

        return repetition_ratio > 0.5

    @staticmethod
    def is_garbage_output(text: str) -> bool:
        """判断是否为垃圾输出"""
        # 检查各种垃圾输出模式
        text_lower = text.lower().strip()

        # 完全重复的单词
        words = text.split()
        if len(words) > 10:
            unique_words = set(words)
            if len(unique_words) <= 3:  # 只有很少的不同词汇
                return True

        # 检查是否主要是重复字符
        if len(text) > 50:
            char_counts = Counter(text.replace(' ', '').replace('\n', ''))
            if char_counts and max(char_counts.values()) > len(text) * 0.4:
                return True

        # 检查是否包含明显的格式错误
        error_patterns = [
            r'^(\w{1,4})\1{10,}',  # 短词重复很多次
            r'(assistant|system|user)\s*\1',  # 系统词重复
        ]

        for pattern in error_patterns:
            if re.search(pattern, text_lower):
                return True

        return False

    @staticmethod
    def smart_repetition_fix(text: str) -> str:
        """智能修复重复问题"""
        # 修复开头的重复问题
        text = re.sub(r'^(\w+)\s*\1\s*(assistant|system)?\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 修复句子级别的重复
        sentences = re.split(r'([.!?]+)', text)
        cleaned_sentences = []
        seen_sentences = set()

        for i in range(0, len(sentences), 2):
            if i < len(sentences):
                sentence = sentences[i].strip()
                sentence_key = re.sub(r'\s+', ' ', sentence.lower())

                if sentence_key and sentence_key not in seen_sentences and len(sentence) > 10:
                    seen_sentences.add(sentence_key)
                    cleaned_sentences.append(sentence)
                    if i + 1 < len(sentences):
                        cleaned_sentences.append(sentences[i + 1])  # 保留标点

        result = ''.join(cleaned_sentences).strip()

        # 如果清理后太短，使用原文
        if len(result) < len(text) * 0.3:
            return text

        return result

    @staticmethod
    def conservative_clean(text: str) -> str:
        """保守的清理策略"""
        original_text = text

        # 只修复明显的问题
        # 1. 移除格式标记
        text = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', text, flags=re.DOTALL)

        # 2. 修复开头的重复词和格式错误
        text = re.sub(r'^(\w+)\s*\1\s*assistant\s*', r'\1 ', text, flags=re.IGNORECASE)
        text = re.sub(r'^(\w+)\s*assistant\s*', r'\1 ', text, flags=re.IGNORECASE)

        # 3. 移除standalone的assistant/system/user
        text = re.sub(r'\b(assistant|system|user)\s+', '', text, flags=re.IGNORECASE)

        # 4. 修复多余空格
        text = re.sub(r'\s+', ' ', text).strip()

        # 如果清理后明显变差，返回原文
        if len(text) < len(original_text) * 0.7:
            return original_text

        return text

    @classmethod
    def process_prediction(cls, text: str) -> Tuple[str, str]:
        """
        处理预测文本
        返回: (清理后的文本, 处理策略)
        """
        # 检查是否为严重的垃圾输出
        if cls.is_garbage_output(text):
            # 对于垃圾输出，尝试从中提取有意义的部分
            words = text.split()
            if words:
                # 尝试找到第一个正常的词开始
                meaningful_start = 0
                for i, word in enumerate(words):
                    if len(word) > 2 and word.lower() not in ['to', 'the', 'and', 'or', 'but']:
                        meaningful_start = i
                        break

                # 取前面一些正常词汇
                if meaningful_start < len(words) - 5:
                    reconstructed = ' '.join(words[meaningful_start:meaningful_start + 20])
                    return reconstructed, "garbage_reconstruction"

            return text, "garbage_kept"

        # 检查严重重复
        elif cls.detect_severe_repetition(text):
            cleaned = cls.smart_repetition_fix(text)
            return cleaned, "repetition_fix"

        # 保守清理
        else:
            cleaned = cls.conservative_clean(text)
            return cleaned, "conservative"

def advanced_evaluation(input_file: str, output_file: str = None):
    """高级评估和清理"""

    # 加载数据
    print("正在加载数据...")
    data = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line.strip()))

    print(f"加载了 {len(data)} 条数据")

    # 分析和清理
    print("正在分析和清理预测...")

    strategy_counts = Counter()
    cleaned_data = []
    severe_issues = []

    for i, item in enumerate(data):
        original_pred = item['predict']
        cleaned_pred, strategy = SmartCleaner.process_prediction(original_pred)

        strategy_counts[strategy] += 1

        # 记录严重问题的样本
        if strategy in ["garbage_reconstruction", "garbage_kept"]:
            severe_issues.append({
                'index': i,
                'strategy': strategy,
                'original': original_pred[:200] + "..." if len(original_pred) > 200 else original_pred,
                'cleaned': cleaned_pred[:200] + "..." if len(cleaned_pred) > 200 else cleaned_pred,
                'label': item['label'][:200] + "..." if len(item['label']) > 200 else item['label']
            })

        item_copy = item.copy()
        item_copy['original_predict'] = original_pred
        item_copy['predict'] = cleaned_pred
        item_copy['clean_strategy'] = strategy
        cleaned_data.append(item_copy)

    print(f"\n清理策略统计:")
    for strategy, count in strategy_counts.items():
        print(f"  {strategy}: {count} 个样本 ({count/len(data)*100:.1f}%)")

    # 显示严重问题样本
    if severe_issues:
        print(f"\n=== 发现 {len(severe_issues)} 个严重问题样本 ===")
        for i, issue in enumerate(severe_issues[:3]):
            print(f"\n样本 {issue['index']} ({issue['strategy']}):")
            print(f"原始: {issue['original']}")
            print(f"清理: {issue['cleaned']}")
            print(f"标签: {issue['label']}")

    # 重新评估
    print("\n正在重新计算评估指标...")

    def normalize_answer(s: str) -> str:
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)
        def white_space_fix(text):
            return ' '.join(text.split())
        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)
        def lower(text):
            return text.lower()
        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def exact_match(prediction: str, ground_truth: str) -> float:
        return float(normalize_answer(prediction) == normalize_answer(ground_truth))

    def f1_score(prediction: str, ground_truth: str) -> float:
        pred_tokens = normalize_answer(prediction).split()
        truth_tokens = normalize_answer(ground_truth).split()

        if len(pred_tokens) == 0 and len(truth_tokens) == 0:
            return 1.0
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return 0.0

        common = Counter(pred_tokens) & Counter(truth_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            return 0.0

        precision = num_same / len(pred_tokens)
        recall = num_same / len(truth_tokens)

        return (2 * precision * recall) / (precision + recall)

    def smart_reranker(prediction: str, ground_truth: str) -> float:
        """智能Reranker，使用多个标准"""
        # 标准1: F1分数高于阈值
        f1 = f1_score(prediction, ground_truth)
        if f1 >= 0.6:
            return 1.0

        # 标准2: 关键词匹配
        pred_words = set(normalize_answer(prediction).split())
        truth_words = set(normalize_answer(ground_truth).split())

        if len(truth_words) > 0:
            keyword_overlap = len(pred_words & truth_words) / len(truth_words)
            if keyword_overlap >= 0.4:
                return 1.0

        return 0.0

    # 计算分数
    strategies = ['conservative', 'repetition_fix', 'garbage_reconstruction', 'garbage_kept']
    results_by_strategy = {}

    for strategy in strategies:
        strategy_items = [item for item in cleaned_data if item['clean_strategy'] == strategy]
        if not strategy_items:
            continue

        em_scores = []
        f1_scores = []
        reranker_scores = []

        for item in strategy_items:
            pred = item['predict']
            label = item['label']

            em_scores.append(exact_match(pred, label))
            f1_scores.append(f1_score(pred, label))
            reranker_scores.append(smart_reranker(pred, label))

        results_by_strategy[strategy] = {
            'count': len(strategy_items),
            'EM': np.mean(em_scores),
            'F1': np.mean(f1_scores),
            'Reranker': np.mean(reranker_scores)
        }

    # 计算总体分数
    all_em_scores = []
    all_f1_scores = []
    all_reranker_scores = []
    all_original_f1_scores = []

    for item in cleaned_data:
        pred = item['predict']
        orig_pred = item['original_predict']
        label = item['label']

        all_em_scores.append(exact_match(pred, label))
        all_f1_scores.append(f1_score(pred, label))
        all_reranker_scores.append(smart_reranker(pred, label))
        all_original_f1_scores.append(f1_score(orig_pred, label))

    overall_results = {
        'EM': np.mean(all_em_scores),
        'F1': np.mean(all_f1_scores),
        'Reranker': np.mean(all_reranker_scores)
    }

    original_f1 = np.mean(all_original_f1_scores)

    # 打印结果
    print(f"\n=== 智能清理结果 ===")
    print(f"总体改进:")
    print(f"  原始F1: {original_f1:.4f}")
    print(f"  清理后F1: {overall_results['F1']:.4f} (改进: {overall_results['F1'] - original_f1:+.4f})")
    print(f"  EM: {overall_results['EM']:.4f}")
    print(f"  智能Reranker: {overall_results['Reranker']:.4f}")

    print(f"\n按策略分类的结果:")
    for strategy, results in results_by_strategy.items():
        print(f"  {strategy} ({results['count']} 样本):")
        print(f"    EM: {results['EM']:.4f}, F1: {results['F1']:.4f}, Reranker: {results['Reranker']:.4f}")

    # 保存结果
    if output_file:
        print(f"\n正在保存清理后的数据到: {output_file}")
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in cleaned_data:
                # 保存清理后的数据，包含策略信息
                save_item = {
                    'prompt': item['prompt'],
                    'predict': item['predict'],
                    'label': item['label'],
                    'clean_strategy': item['clean_strategy']
                }
                f.write(json.dumps(save_item, ensure_ascii=False) + '\n')

    return overall_results, results_by_strategy

# 运行智能清理
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl"
    output_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/smart_cleaned_predictions.jsonl"

    overall_results, strategy_results = advanced_evaluation(input_file, output_file)

In [None]:
import json
import time
import os
from openai import OpenAI

# 设置API Key
os.environ["OPENAI_API_KEY"] = ""
client = OpenAI()

def read_jsonl_data(path):
    """读取JSONL数据"""
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                data.append({
                    'prompt': item.get('prompt', ''),
                    'prediction': item.get('predict', ''),
                    'reference': item.get('label', '')
                })
    return data

def gpt_judge(prediction, reference, prompt=""):
    """直接用GPT判断"""

    # 识别任务类型
    task_type = "general conversation"
    if prompt:
        prompt_lower = prompt.lower()
        if any(word in prompt_lower for word in ['code', 'programming', 'function', 'algorithm']):
            task_type = "coding"
        elif any(word in prompt_lower for word in ['explain', 'what is', 'describe', 'definition']):
            task_type = "explanation"
        elif any(word in prompt_lower for word in ['list', 'steps', 'how to', 'instruction']):
            task_type = "instruction"
        elif any(word in prompt_lower for word in ['analyze', 'compare', 'evaluate']):
            task_type = "analysis"

    judge_prompt = f"""You are evaluating an AI assistant's response for a {task_type} task.

Evaluation Guidelines:
- Focus on factual correctness and helpfulness
- Accept different but semantically equivalent answers
- Ignore formatting, style, and minor wording differences
- Consider the core message and information accuracy
- Be reasonably lenient - if the prediction conveys the same key information as the reference, it should be considered correct
- For explanations: focus on conceptual accuracy rather than exact phrasing
- For instructions: check if the steps achieve the same goal
- For analysis: evaluate logical reasoning and key conclusions

Important: Many responses may be worded differently but still be factually correct and helpful.

Output format: {{"score": 1, "reason": "explanation"}} for correct responses
               {{"score": 0, "reason": "explanation"}} for incorrect responses

Score 1: The prediction is factually correct and helpful (even if expressed differently)
Score 0: The prediction is factually wrong, misleading, or significantly unhelpful"""

    # 限制文本长度以节省token
    max_length = 1000
    pred_text = prediction[:max_length]
    ref_text = reference[:max_length]

    # 添加省略号如果被截断
    if len(prediction) > max_length:
        pred_text += " [truncated...]"
    if len(reference) > max_length:
        ref_text += " [truncated...]"

    user_message = f"""Task Type: {task_type}

AI Prediction:
{pred_text}

Reference Answer:
{ref_text}

Is the AI prediction factually correct and helpful compared to the reference? Output JSON only."""

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            temperature=0.1,
            messages=[
                {"role": "system", "content": judge_prompt},
                {"role": "user", "content": user_message}
            ],
            max_tokens=150
        )

        content = response.choices[0].message.content.strip()

        # 解析JSON
        try:
            # 查找JSON部分
            start = content.find("{")
            end = content.rfind("}") + 1

            if start != -1 and end > start:
                json_text = content[start:end]
                result = json.loads(json_text)

                score = int(result.get("score", 0))
                reason = result.get("reason", "no reason provided")

                return score, reason
            else:
                # 如果没有找到JSON，尝试从文本中提取
                if "score" in content.lower():
                    if "1" in content and ("correct" in content.lower() or "accurate" in content.lower()):
                        return 1, "extracted_positive"
                    else:
                        return 0, "extracted_negative"
                return 0, f"parse_failed: {content}"

        except json.JSONDecodeError:
            # JSON解析失败，基于关键词判断
            content_lower = content.lower()
            if any(word in content_lower for word in ["correct", "accurate", "helpful", "good"]):
                return 1, f"keyword_positive: {content[:50]}"
            else:
                return 0, f"keyword_negative: {content[:50]}"

    except Exception as e:
        return 0, f"api_error: {str(e)}"

def evaluate_direct_gpt(input_path, sample_size=None, save_results=True):
    """直接用GPT评估所有样本"""

    print("=== 直接GPT判断评估 ===")
    print(f"输入文件: {input_path}")

    # 读取数据
    data = read_jsonl_data(input_path)
    print(f"加载了 {len(data)} 个样本")

    # 采样
    if sample_size and sample_size < len(data):
        import random
        data = random.sample(data, sample_size)
        print(f"采样 {sample_size} 个样本进行评估")

    results = []
    correct_count = 0
    error_count = 0

    print(f"\n开始处理 {len(data)} 个样本...")

    for i, item in enumerate(data):
        print(f"处理样本 {i+1}/{len(data)}", end="")

        pred = item['prediction']
        ref = item['reference']
        prompt = item['prompt']

        # 直接调用GPT判断
        score, reason = gpt_judge(pred, ref, prompt)

        if score == 1:
            correct_count += 1
            print(" ✓")
        else:
            print(f" ✗ ({reason[:30]}...)")

        if "api_error" in reason:
            error_count += 1

        results.append({
            'index': i,
            'prompt': prompt[:100] + "..." if len(prompt) > 100 else prompt,
            'prediction': pred[:200] + "..." if len(pred) > 200 else pred,
            'reference': ref[:200] + "..." if len(ref) > 200 else ref,
            'gpt_score': score,
            'gpt_reason': reason
        })

        # API限速
        time.sleep(0.5)

        # 每10个样本显示进度
        if (i + 1) % 10 == 0:
            current_accuracy = correct_count / (i + 1)
            print(f"  进度: {i+1}/{len(data)}, 当前准确率: {current_accuracy:.3f}")

    # 计算最终结果
    total_samples = len(results)
    accuracy = correct_count / total_samples

    print(f"\n=== 评估结果 ===")
    print(f"总样本数: {total_samples}")
    print(f"正确样本: {correct_count}")
    print(f"错误样本: {total_samples - correct_count}")
    print(f"API错误: {error_count}")
    print(f"最终准确率: {accuracy:.4f} ({accuracy*100:.2f}%)")

    # 显示一些示例
    print(f"\n=== 正确样本示例 ===")
    correct_examples = [r for r in results if r['gpt_score'] == 1][:2]
    for ex in correct_examples:
        print(f"预测: {ex['prediction']}")
        print(f"参考: {ex['reference']}")
        print(f"原因: {ex['gpt_reason']}")
        print()

    print(f"=== 错误样本示例 ===")
    incorrect_examples = [r for r in results if r['gpt_score'] == 0][:2]
    for ex in incorrect_examples:
        print(f"预测: {ex['prediction']}")
        print(f"参考: {ex['reference']}")
        print(f"原因: {ex['gpt_reason']}")
        print()

    # 保存结果
    if save_results:
        output_file = f"gpt_direct_evaluation_results.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'summary': {
                    'total_samples': total_samples,
                    'correct_count': correct_count,
                    'accuracy': accuracy,
                    'api_errors': error_count
                },
                'detailed_results': results
            }, f, indent=2, ensure_ascii=False)

        print(f"\n结果已保存到: {output_file}")

    return results

def quick_test(input_path, n_samples=3):
    """快速测试GPT调用"""
    print("=== 快速GPT测试 ===")

    data = read_jsonl_data(input_path)
    if not data:
        print("无法读取数据")
        return

    # 测试前几个样本
    for i in range(min(n_samples, len(data))):
        item = data[i]
        print(f"\n--- 测试样本 {i+1} ---")
        print(f"预测: {item['prediction'][:100]}...")
        print(f"参考: {item['reference'][:100]}...")

        score, reason = gpt_judge(item['prediction'], item['reference'], item['prompt'])

        print(f"GPT评分: {score}")
        print(f"原因: {reason}")

        time.sleep(1)  # 短暂暂停

# 使用示例
if __name__ == "__main__":
    input_file = "/content/drive/MyDrive/llama_saves/Qwen2-VL-2B/lora/sharegpttest/generated_predictions.jsonl"

    results = evaluate_direct_gpt(input_file, sample_size=None, save_results=True)