In [4]:
import argparse
import json
import os
import warnings
from collections import defaultdict

import docker
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [5]:

def parse_args(argv=None):
    """Build the validation-run configuration from command-line style arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Passing a list (for example
            ``["--run_id", "0401"]``) lets a notebook cell override the
            defaults explicitly. ``None`` (the default) reads ``sys.argv``.

    Returns:
        argparse.Namespace: carries the ``root_dir``, ``hf_name``, ``run_id``
        and ``split`` attributes. Unrecognised arguments are dropped.
    """
    parser = argparse.ArgumentParser(
        description="运行验证任务的配置参数",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter  # show default values in --help
    )

    # Path options
    parser.add_argument(
        "--root_dir", "-r",
        type=str,
        default='logs/run_validation/',
        help="日志存储根目录"
    )

    # Dataset repository options
    parser.add_argument(
        "--hf_name", "-m",
        type=str,
        default='zengliangcs/SWE-Fixer-Ours-1',
        help="Hugging Face仓库名称"
    )

    # Run identifier
    parser.add_argument(
        "--run_id",
        type=str,
        default='0329',
        help="运行id号"
    )

    # Data split options
    parser.add_argument(
        "--split", "-s",
        type=str,
        choices=['train', 'validation', 'test'],
        default='train',
        help="数据集划分类型"
    )

    # parse_known_args returns (namespace, leftovers); unknown arguments are
    # ignored rather than aborting, which matters when sys.argv carries
    # options that belong to the host process instead of this script.
    known_args, _unknown = parser.parse_known_args(argv)
    return known_args

In [6]:
def get_validations(validation_logs, root_dir, test_type):
    """Collect the PASS/FAIL test sets of every instance for one run type.

    Scans ``<root_dir>/<test_type>/<instance_id>/report.json``. For each report
    whose ``data[instance_id]['tests_status']`` mapping is non-empty, stores
    the ``PASS`` and ``FAIL`` lists as sets in ``validation_logs`` under the
    keys ``"<test_type>-PASS"`` and ``"<test_type>-FAIL"``.

    Args:
        validation_logs: Mapping that is updated in place. It is indexed as
            ``validation_logs[instance_id][key] = ...``, so it must create
            missing instance entries on access (e.g. a nested ``defaultdict``).
        root_dir: Directory that contains one sub-directory per run type.
        test_type: Name of the run-type sub-directory to scan.

    Returns:
        tuple[int, int]: ``(total_instances_num, instance_with_report_num)``:
        the number of instance directories found, and the number of those
        that contain a ``report.json`` file (readable or not).

    Raises:
        OSError: if ``<root_dir>/<test_type>`` cannot be listed.
    """
    total_instances_num = 0
    instance_with_report_num = 0
    unusable_reports = []
    root_dir = os.path.join(root_dir, test_type)
    for instance_id in os.listdir(root_dir):
        instance_id_dir = os.path.join(root_dir, instance_id)
        if not os.path.isdir(instance_id_dir):
            continue
        total_instances_num += 1

        report_path = os.path.join(instance_id_dir, 'report.json')
        if not os.path.isfile(report_path):
            continue
        instance_with_report_num += 1

        try:
            # JSON is read as UTF-8 explicitly instead of the locale default.
            with open(report_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (ValueError, OSError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            unusable_reports.append(report_path)
            continue

        # The report is keyed by the instance directory name; anything that is
        # not a mapping at either level is treated as an unusable report
        # instead of raising AttributeError on ``.get``.
        entry = data.get(instance_id, {}) if isinstance(data, dict) else None
        tests_status = entry.get('tests_status', {}) if isinstance(entry, dict) else None
        if not isinstance(tests_status, dict):
            unusable_reports.append(report_path)
            continue

        if tests_status:
            pass_set = set(tests_status.get('PASS') or [])
            fail_set = set(tests_status.get('FAIL') or [])
            validation_logs[instance_id][f"{test_type}-PASS"] = pass_set
            validation_logs[instance_id][f"{test_type}-FAIL"] = fail_set

    if unusable_reports:
        # One summary warning instead of a silent skip, so corrupt reports are
        # visible without flooding the output.
        preview = ', '.join(unusable_reports[:5])
        warnings.warn(
            f"{len(unusable_reports)} report.json file(s) under {root_dir} "
            f"could not be used and were skipped (first entries: {preview})",
            RuntimeWarning,
            stacklevel=2,
        )
    return total_instances_num, instance_with_report_num


In [7]:

def get_success_validation_data(root_dir):
    """Derive FAIL_TO_PASS / PASS_TO_PASS test lists from the validation logs.

    Runs ``get_validations`` for the ``'gold'`` and ``'empty'`` run types and,
    per instance, intersects the collected sets:

    * ``FAIL_TO_PASS`` = tests in ``gold-PASS`` and in ``empty-FAIL``
    * ``PASS_TO_PASS`` = tests in ``gold-PASS`` and in ``empty-PASS``

    Only instances with a non-empty ``FAIL_TO_PASS`` set are kept.

    Args:
        root_dir: Directory containing the ``gold`` and ``empty`` sub-directories.

    Returns:
        tuple: ``(success_validation_data, total_instances_num,
        instance_with_report_num, saved_num)`` where
        ``success_validation_data`` maps instance id to a mapping with
        ``'FAIL_TO_PASS'`` and ``'PASS_TO_PASS'`` lists (sorted, so repeated
        runs over the same logs yield identical output), the two count dicts
        are keyed by run type, and ``saved_num`` is the number of kept
        instances.
    """
    total_instances_num = {'gold': 0, 'empty': 0}
    instance_with_report_num = {'gold': 0, 'empty': 0}
    saved_num = 0
    validation_logs = defaultdict(lambda: defaultdict(set))
    success_validation_data = defaultdict(lambda: defaultdict(list))

    for test_type in ['gold', 'empty']:
        type_total, type_with_report = get_validations(validation_logs, root_dir, test_type)
        total_instances_num[test_type] = type_total
        instance_with_report_num[test_type] = type_with_report

    for instance_id, tests_status in validation_logs.items():
        # Missing keys resolve to an empty set through the inner defaultdict,
        # so an instance seen in only one run type yields empty intersections.
        gold_pass = tests_status["gold-PASS"]
        fail_to_pass = gold_pass & tests_status["empty-FAIL"]
        pass_to_pass = gold_pass & tests_status["empty-PASS"]
        # Keep only instances whose fail_to_pass set is non-empty.
        if fail_to_pass:
            saved_num += 1
            # Set iteration order is not stable across interpreter runs;
            # sorting makes the exported lists reproducible.
            success_validation_data[instance_id]['FAIL_TO_PASS'] = sorted(fail_to_pass)
            success_validation_data[instance_id]['PASS_TO_PASS'] = sorted(pass_to_pass)
    return success_validation_data, total_instances_num, instance_with_report_num, saved_num


In [8]:

def filter_dataset(hf_name, split, success_validation_data):
    """Load a dataset split and keep only the successfully validated rows.

    Args:
        hf_name: Dataset identifier passed to ``load_dataset``.
        split: Split name passed to ``load_dataset``.
        success_validation_data: Mapping from instance id to a mapping holding
            ``'FAIL_TO_PASS'`` and ``'PASS_TO_PASS'`` values.

    Returns:
        The rows whose ``instance_id`` is a key of ``success_validation_data``,
        each with its ``FAIL_TO_PASS`` and ``PASS_TO_PASS`` fields set from
        that mapping.
    """
    def is_validated(row):
        # Membership test only; does not insert into the mapping.
        return row['instance_id'] in success_validation_data.keys()

    def attach_test_lists(row):
        key = row['instance_id']
        row['FAIL_TO_PASS'] = success_validation_data[key]['FAIL_TO_PASS']
        row['PASS_TO_PASS'] = success_validation_data[key]['PASS_TO_PASS']
        return row

    full_split = load_dataset(hf_name, split=split)

    # Drop unvalidated rows first, then add the two fields to what remains.
    kept_rows = full_split.filter(is_validated)
    validated_instances = kept_rows.map(attach_test_lists)
    return validated_instances


In [None]:
# Driver cell: resolve the configuration, aggregate the validation logs,
# filter the dataset and export the validated instances as JSON Lines.
args = parse_args()
print(f"当前配置：\n"
      f"日志目录：{args.root_dir}\n"
      f"运行ID：{args.run_id}\n"
      f"仓库名称：{args.hf_name}\n"
      f"数据划分：{args.split}")

# Logs of a single run live under <root_dir>/<run_id>.
root_dir = os.path.join(args.root_dir, args.run_id)
success_validation_data, total_instances_num, instance_with_report_num, saved_num = \
    get_success_validation_data(root_dir)

print(f"instance总数：{total_instances_num=}")
print(f"生成report.json的instance总数：{instance_with_report_num=}")
print(f"保存的instance总数：{saved_num=}")

validated_instances = filter_dataset(args.hf_name, args.split, success_validation_data)

# Save locally. The file name is derived from hf_name, which contains a "/"
# for "<owner>/<name>" repository ids, so the output path has a parent
# directory. Create it up front instead of relying on the writer to do so.
output_path = f"{args.hf_name}.jsonl"
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

validated_instances.to_json(
    output_path,
    orient="records",
    lines=True,
    batch_size=10000,
    force_ascii=False
)


当前配置：
日志目录：logs/run_validation/
运行ID：0329
模型名称：zengliangcs/SWE-Fixer-Ours-1
数据划分：train


instance总数：total_instances_num={'gold': 7780, 'empty': 7770}
生成report.json的instance总数：instance_with_report_num={'gold': 4604, 'empty': 4647}
保存的instance总数：saved_num=552


Filter: 100%|██████████| 7780/7780 [00:00<00:00, 71624.80 examples/s]
Map: 100%|██████████| 552/552 [00:00<00:00, 8470.61 examples/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  5.05ba/s]


10715960