In [1]:
import re
from collections import defaultdict
from tqdm import trange
def extract_answers(text_list):
    """改进版答案提取函数，增加容错处理"""
    pattern = re.compile(
        r'Answer:\s*(.+?)(?=\[Answer Completed\]|<\|endoftext\|>|$)', 
        flags=re.DOTALL
    )
    return [match.group(1).strip() if (match := pattern.search(text)) else "No Answer Extracted" 
            for text in text_list]

def normalize_string(s):
    """字符串标准化：去除前后空格+转小写+移除标点"""
    return s.strip().lower().translate(str.maketrans('', '', ',.!?'))


In [3]:

def evaluate_responses(dataset, model, tokenizer, batch_size=8):
    """
    批量处理评估函数
    
    参数：
        dataset: 包含prompt和completion的数据集
        model: 已加载的生成模型
        tokenizer: 对应的tokenizer
        batch_size: 处理批次大小
        
    返回：
        results: 包含详细结果的字典列表
        metrics: 包含统计指标的字典
    """
    # 初始化数据结构
    results = []
    counter = defaultdict(int)
    
    # 按批次处理
    for i in trange(0, len(dataset), batch_size):
        batch = dataset[i:i+batch_size]
        with torch.no_grad():
            # 模型生成
            inputs = tokenizer(
                batch['prompt'], 
                return_tensors='pt', 
                padding=True, 
                truncation=True,
                max_length=512  # 根据模型调整
            ).to(model.device)
            
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id  # 确保与padding策略一致
            )
            
            # 解码并提取答案
            responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        pred_answers = extract_answers(responses)
        
        # 处理每个样本
        for j in trange(len(batch['prompt'])):
            # 获取参考答案
            true_answer = normalize_string(batch['completion'][j])
            pred_answer = normalize_string(pred_answers[j])
            
            # 判断正确性
            is_correct = (pred_answer == true_answer) if pred_answer != "noanswerextracted" else False
            
            # 记录结果
            record = {
                "prompt": batch['prompt'][j],
                "true_answer": true_answer,
                "pred_answer": pred_answer,
                "is_correct": is_correct,
                "raw_response": responses[j]
            }
            results.append(record)
            
            # 统计计数
            counter['total'] += 1
            if is_correct:
                counter['correct'] += 1
            if "noanswerextracted" in pred_answer:
                counter['no_answer'] += 1

    # 计算指标
    metrics = {
        "accuracy": counter['correct'] / counter['total'] if counter['total'] else 0,
        "total_samples": counter['total'],
        "correct_count": counter['correct'],
        "no_answer_rate": counter['no_answer'] / counter['total'] if counter['total'] else 0
    }
    
    return results, metrics

In [2]:
from transformers import AutoModelForCausalLM
import torch
from transformers import AutoTokenizer
model_paths = []
model_paths += ['/data0/leileqi/Paper-reproduction/pretrain/pretrained_models_NTP_from_scratch_No-RandomShuffleEveryEpoch-SFT/epoch_0']
model_paths += ['/data0/leileqi/Paper-reproduction/finetune/pretrained_models_NTP_from_scratch_ShuffleDataEveryEpoch-SFT/epoch_0']
model_paths

  from .autonotebook import tqdm as notebook_tqdm


['/data0/leileqi/Paper-reproduction/pretrain/pretrained_models_NTP_from_scratch_No-RandomShuffleEveryEpoch-SFT/epoch_0',
 '/data0/leileqi/Paper-reproduction/finetune/pretrained_models_NTP_from_scratch_ShuffleDataEveryEpoch-SFT/epoch_0']

In [3]:

for model_path in model_paths:
    print(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda:0')
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    import gc
    gc.collect()
    torch.cuda.empty_cache()

    data_path = '/data0/leileqi/Paper-reproduction/finetune/test_half_template_deduped_20%subset.jsonl'

    from datasets import load_dataset
    ds = load_dataset('json', data_files=data_path)['train']
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token



    results, metrics = evaluate_responses(ds, model, tokenizer, batch_size=1024)

    print(f"\n评估结果：")
    print(f"准确率: {metrics['accuracy']:.2%}")
    print(f"总样本数: {metrics['total_samples']}")
    print(f"未提取答案比例: {metrics['no_answer_rate']:.2%}")

    # 查看前3个样本的详细信息
    print("\n示例样本记录：")
    for r in results[:3]:
        print(f"True: {r['true_answer']} | Pred: {r['pred_answer']} | Correct: {r['is_correct']}")
    import json
    from datetime import datetime
    with open('./eval_log.jsonl', 'a') as f:
        data = dict(
            time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            model_path=model_path,
            metrics=metrics,
            data_path = data_path,
            results=results
        )
        f.write(json.dumps(data, ensure_ascii=False,) + '\n')
        
    print("评估结果已保存到eval_log.jsonl")

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48


/data1/leileqi/MTPhysics/pretrain/pretrained_models_MTP_from_scratch_No-ShuffleDataEveryEpoch-SFT/epoch_0


100%|██████████| 1024/1024 [00:00<00:00, 446045.00it/s]
100%|██████████| 1024/1024 [00:00<00:00, 453725.68it/s]
100%|██████████| 1024/1024 [00:00<00:00, 453629.84it/s]
100%|██████████| 1024/1024 [00:00<00:00, 465074.96it/s]
100%|██████████| 1024/1024 [00:00<00:00, 468627.09it/s]
100%|██████████| 1024/1024 [00:00<00:00, 459748.16it/s]
100%|██████████| 1024/1024 [00:00<00:00, 456814.22it/s]
100%|██████████| 1024/1024 [00:00<00:00, 477430.78it/s]
100%|██████████| 1024/1024 [00:00<00:00, 467098.13it/s]
100%|██████████| 1024/1024 [00:00<00:00, 463969.68it/s]
100%|██████████| 1024/1024 [00:00<00:00, 467861.36it/s]
100%|██████████| 1024/1024 [00:00<00:00, 467098.13it/s]
100%|██████████| 1024/1024 [00:00<00:00, 468831.71it/s]
100%|██████████| 1024/1024 [00:00<00:00, 466793.53it/s]
100%|██████████| 1024/1024 [00:00<00:00, 471818.88it/s]
100%|██████████| 1024/1024 [00:00<00:00, 468678.23it/s]
100%|██████████| 1024/1024 [00:00<00:00, 463069.25it/s]
100%|██████████| 1024/1024 [00:00<00:00, 463669.


评估结果：
准确率: 60.78%
总样本数: 213378
未提取答案比例: 0.00%

示例样本记录：
True: february 27 1993 | Pred: november 23 1982 | Correct: False
True: east matthew | Pred: new michael | Correct: False
True: university of california los angeles | Pred: university of toronto | Correct: False
评估结果已保存到eval_log.jsonl
/data1/leileqi/MTPhysics/finetune/pretrained_models_MTP_from_scratch_ShuffleDataEveryEpoch-SFT/epoch_0


100%|██████████| 1024/1024 [00:00<00:00, 412501.66it/s]
100%|██████████| 1024/1024 [00:00<00:00, 432219.71it/s]
100%|██████████| 1024/1024 [00:00<00:00, 449311.36it/s]
100%|██████████| 1024/1024 [00:00<00:00, 461229.31it/s]
100%|██████████| 1024/1024 [00:00<00:00, 464521.66it/s]
100%|██████████| 1024/1024 [00:00<00:00, 457836.83it/s]
100%|██████████| 1024/1024 [00:00<00:00, 462222.05it/s]
100%|██████████| 1024/1024 [00:00<00:00, 451057.27it/s]
100%|██████████| 1024/1024 [00:00<00:00, 455361.25it/s]
100%|██████████| 1024/1024 [00:00<00:00, 443740.81it/s]
100%|██████████| 1024/1024 [00:00<00:00, 454878.98it/s]
100%|██████████| 1024/1024 [00:00<00:00, 452435.19it/s]
100%|██████████| 1024/1024 [00:00<00:00, 462470.91it/s]
100%|██████████| 1024/1024 [00:00<00:00, 450773.23it/s]
100%|██████████| 1024/1024 [00:00<00:00, 457983.29it/s]
100%|██████████| 1024/1024 [00:00<00:00, 453007.84it/s]
100%|██████████| 1024/1024 [00:00<00:00, 446045.00it/s]
100%|██████████| 1024/1024 [00:00<00:00, 460339.

/data0/leileqi/Paper-reproduction/pretrain/pretrained_models_NTP_from_scratch_No-RandomShuffleEveryEpoch-SFT/epoch_0


The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48
100%|██████████| 1024/1024 [00:00<00:00, 460882.85it/s]
100%|██████████| 1024/1024 [00:00<00:00, 473326.79it/s]
100%|██████████| 1024/1024 [00:00<00:00, 466793.53it/s]
100%|██████████| 1024/1024 [00:00<00:00, 479349.03it/s]
100%|██████████| 1024/1024 [00:00<00:00, 474844.37it/s]
100%|██████████| 1024/1024 [00:00<00:00, 477962.09it/s]
100%|██████████| 1024/1024 [00:00<00:00, 482147.20it/s]
100%|██████████| 1024/1024 [00:00<00:00, 489957.48it/s]
100%|██████████| 1024/1024 [00:00<00:00, 484267.37it/s]
100%|██████████| 1024/1024 [00:00<00:00, 478814.64it/s]
100%|██████████| 1024/1024 [00:00<00:00, 484540.53it/s]
100%|██████████| 1024/1024 [00:00<00:00, 478547.89it/s]
100%|██████████| 1024/1024 [00:00<00:00, 484431.23it/s]
100%|██████████| 1024/1024 [00:00<00:00, 482201.34it/s]
100%|██████████| 1024/1024 [00:00<00


评估结果：
准确率: 0.98%
总样本数: 213378
未提取答案比例: 0.00%

示例样本记录：
True: february 27 1993 | Pred: june 14 1984 | Correct: False
True: east matthew | Pred: west david | Correct: False
True: university of california los angeles | Pred: university of california berkeley | Correct: False
评估结果已保存到eval_log.jsonl
/data0/leileqi/Paper-reproduction/finetune/pretrained_models_NTP_from_scratch_ShuffleDataEveryEpoch-SFT/epoch_0


100%|██████████| 1024/1024 [00:00<00:00, 419922.50it/s]
100%|██████████| 1024/1024 [00:00<00:00, 430098.87it/s]
100%|██████████| 1024/1024 [00:00<00:00, 454975.35it/s]
100%|██████████| 1024/1024 [00:00<00:00, 458766.00it/s]
100%|██████████| 1024/1024 [00:00<00:00, 464873.61it/s]
100%|██████████| 1024/1024 [00:00<00:00, 462421.11it/s]
100%|██████████| 1024/1024 [00:00<00:00, 463069.25it/s]
100%|██████████| 1024/1024 [00:00<00:00, 466033.78it/s]
100%|██████████| 1024/1024 [00:00<00:00, 466844.27it/s]
100%|██████████| 1024/1024 [00:00<00:00, 463769.28it/s]
100%|██████████| 1024/1024 [00:00<00:00, 465175.71it/s]
100%|██████████| 1024/1024 [00:00<00:00, 458472.17it/s]
100%|██████████| 1024/1024 [00:00<00:00, 471352.86it/s]
100%|██████████| 1024/1024 [00:00<00:00, 464773.00it/s]
100%|██████████| 1024/1024 [00:00<00:00, 462869.63it/s]
100%|██████████| 1024/1024 [00:00<00:00, 460882.85it/s]
100%|██████████| 1024/1024 [00:00<00:00, 468422.65it/s]
100%|██████████| 1024/1024 [00:00<00:00, 463619.


评估结果：
准确率: 38.02%
总样本数: 213378
未提取答案比例: 0.00%

示例样本记录：
True: february 27 1993 | Pred: february 27 1999 | Correct: False
True: east matthew | Pred: east matthewmouth | Correct: False
True: university of california los angeles | Pred: university of california san diego | Correct: False

评估结果：
准确率: 96.36%
总样本数: 213378
未提取答案比例: 0.00%

示例样本记录：
True: february 27 1993 | Pred: february 27 1993 | Correct: True
True: east matthew | Pred: east matthew | Correct: True
True: university of california los angeles | Pred: university of california los angeles | Correct: True
评估结果已保存到eval_log.jsonl
