In [None]:
import json
from collections import defaultdict
import os
import pandas as pd
import docker
from tqdm import tqdm
from datasets import load_dataset
import argparse


In [None]:

def parse_args(argv=None):
    """Parse the command-line options for the validation-filtering run.

    ``parse_known_args`` is used instead of ``parse_args`` so that the extra
    arguments a Jupyter kernel places in ``sys.argv`` (e.g.
    ``-f /path/to/kernel.json``) do not make argparse abort when this code is
    executed inside a notebook. Unrecognized arguments are reported, not fatal.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) means
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with the attributes ``root_dir``, ``hf_name`` and
        ``split``.
    """
    parser = argparse.ArgumentParser(
        description="运行验证任务的配置参数",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter  # show default values in --help
    )

    # Path options
    parser.add_argument(
        "--root_dir", "-r",
        type=str,
        default='/root/projects/main_projects/logs/run_validation/',
        help="日志存储根目录"
    )

    # Model / dataset repository options
    parser.add_argument(
        "--hf_name", "-m",
        type=str,
        default='zengliangcs/SWE-Fixer-Ours-2',
        help="Hugging Face仓库名称"
    )

    # Data split options
    parser.add_argument(
        "--split", "-s",
        type=str,
        choices=['train', 'validation', 'test'],
        default='train',
        help="数据集划分类型"
    )

    args, unknown = parser.parse_known_args(argv)
    if unknown:
        # Typically the kernel's "-f <connection file>" when run in Jupyter.
        print(f"Ignoring unrecognized arguments: {unknown}")
    return args

In [None]:
def get_validations(validation_logs, root_dir, test_type):
    """Collect per-instance PASS/FAIL test sets for one validation type of a run.

    Scans ``<root_dir>/<test_type>/<instance_id>/report.json``. For every
    readable report with a non-empty ``tests_status`` entry under its instance
    id, the PASS and FAIL test names are stored as sets in
    ``validation_logs[instance_id]["<test_type>-PASS"]`` and
    ``validation_logs[instance_id]["<test_type>-FAIL"]`` (mutated in place).

    Args:
        validation_logs: Mapping instance_id -> dict of test-name sets. The
            caller passes ``defaultdict(lambda: defaultdict(set))`` so unseen
            instance ids are created on first write.
        root_dir: Directory of a single validation run.
        test_type: Sub-directory to scan, e.g. ``'gold'`` or ``'empty'``.

    Returns:
        Tuple ``(total_instances_num, instance_with_report_num)``: how many
        instance directories were found and how many contain a report.json.
        Both are 0 when ``<root_dir>/<test_type>`` does not exist.
    """
    total_instances_num = 0
    instance_with_report_num = 0
    type_dir = os.path.join(root_dir, test_type)
    # A run may have produced only one of the validation types; treat a missing
    # directory as "no instances" instead of aborting the whole scan.
    if not os.path.isdir(type_dir):
        return total_instances_num, instance_with_report_num

    for instance_id in os.listdir(type_dir):
        instance_id_dir = os.path.join(type_dir, instance_id)
        if not os.path.isdir(instance_id_dir):
            continue
        total_instances_num += 1

        report_path = os.path.join(instance_id_dir, 'report.json')
        if not os.path.isfile(report_path):
            continue
        instance_with_report_num += 1

        try:
            with open(report_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # Best effort: skip a corrupt/unreadable report, but make it visible.
            # (ValueError covers json.JSONDecodeError and UnicodeDecodeError.)
            print(f"Skipping unreadable report {report_path}: {e}")
            continue

        # The report is keyed by the instance id (= the sub-directory name).
        instance_report = data.get(instance_id) if isinstance(data, dict) else None
        tests_status = instance_report.get('tests_status') if isinstance(instance_report, dict) else None
        if tests_status:
            validation_logs[instance_id][f"{test_type}-PASS"] = set(tests_status.get('PASS', []))
            validation_logs[instance_id][f"{test_type}-FAIL"] = set(tests_status.get('FAIL', []))
    return total_instances_num, instance_with_report_num


In [None]:

def get_success_validation_data(root_dir):
    """Aggregate the gold/empty validation reports of every run under ``root_dir``.

    Each sub-directory of ``root_dir`` is treated as one validation run that
    may contain ``gold`` and ``empty`` result directories (scanned with
    ``get_validations``). After all runs are merged, for each instance:

    - FAIL_TO_PASS = tests passing with the gold patch AND failing in the empty run
    - PASS_TO_PASS = tests passing with the gold patch AND passing in the empty run

    Only instances with a non-empty FAIL_TO_PASS set are kept. If the same
    instance appears in several runs, the later-scanned run's sets replace the
    earlier ones (``get_validations`` assigns, it does not union).

    Args:
        root_dir: Directory whose sub-directories are individual runs.

    Returns:
        Tuple ``(success_validation_data, total_instances_num,
        instance_with_report_num, saved_num)``:
        - success_validation_data: instance_id -> {'FAIL_TO_PASS': sorted list,
          'PASS_TO_PASS': sorted list}
        - total_instances_num / instance_with_report_num: dicts keyed by
          'gold' and 'empty', summed over all runs
        - saved_num: number of distinct instances kept
    """
    total_instances_num = {'gold': 0, 'empty': 0}
    instance_with_report_num = {'gold': 0, 'empty': 0}
    validation_logs = defaultdict(lambda: defaultdict(set))
    success_validation_data = defaultdict(lambda: defaultdict(list))

    for run_ids in os.listdir(root_dir):
        run_ids_dir = os.path.join(root_dir, run_ids)
        # Stray files next to the run directories are not runs.
        if not os.path.isdir(run_ids_dir):
            continue
        print(f"正在筛选{run_ids=}")
        for test_type in ['gold', 'empty']:
            type_total_instances_num, type_instance_with_report_num = get_validations(
                validation_logs, run_ids_dir, test_type
            )
            # Accumulate over runs instead of keeping only the last run's counts.
            total_instances_num[test_type] += type_total_instances_num
            instance_with_report_num[test_type] += type_instance_with_report_num

    # Derive the test lists once, after every run has been merged, so that each
    # instance is counted at most once in saved_num.
    for instance_id, tests_status in validation_logs.items():
        gold_pass = tests_status.get("gold-PASS", set())
        empty_fail = tests_status.get("empty-FAIL", set())
        empty_pass = tests_status.get("empty-PASS", set())
        fail_to_pass = gold_pass & empty_fail
        pass_to_pass = gold_pass & empty_pass
        # Keep only instances whose gold patch actually fixes at least one test.
        if fail_to_pass:
            # Sorted for a deterministic output (set order varies between processes).
            success_validation_data[instance_id]['FAIL_TO_PASS'] = sorted(fail_to_pass)
            success_validation_data[instance_id]['PASS_TO_PASS'] = sorted(pass_to_pass)

    saved_num = len(success_validation_data)
    return success_validation_data, total_instances_num, instance_with_report_num, saved_num


In [None]:

def filter_dataset(hf_name, split, success_validation_data):
    """Load a Hugging Face dataset split and keep only the validated instances.

    A row is kept when its ``instance_id`` is a key of
    ``success_validation_data``. Each kept row gets its ``FAIL_TO_PASS`` and
    ``PASS_TO_PASS`` fields set from the matching validation entry.

    Args:
        hf_name: Dataset repository name passed to ``load_dataset``.
        split: Split to load (e.g. ``'train'``).
        success_validation_data: Mapping instance_id ->
            {'FAIL_TO_PASS': list, 'PASS_TO_PASS': list}.

    Returns:
        The filtered, annotated dataset.
    """
    def is_validated(row):
        # Membership test against the validated instance ids.
        return row['instance_id'] in success_validation_data.keys()

    def attach_test_lists(row):
        validated = success_validation_data[row['instance_id']]
        row['FAIL_TO_PASS'] = validated['FAIL_TO_PASS']
        row['PASS_TO_PASS'] = validated['PASS_TO_PASS']
        return row

    raw_dataset = load_dataset(hf_name, split=split)
    kept_rows = raw_dataset.filter(is_validated)
    return kept_rows.map(attach_test_lists)


In [None]:
# Entry point: read options, aggregate validation logs, filter the dataset and
# write the validated instances to a local JSON Lines file.
args = parse_args()
print(f"当前配置：\n"
      f"日志目录：{args.root_dir}\n"
      f"模型名称：{args.hf_name}\n"
      f"数据划分：{args.split}")

success_validation_data, total_instances_num, instance_with_report_num, saved_num = \
    get_success_validation_data(args.root_dir)

print(f"instance总数：{total_instances_num=}")
print(f"生成report.json的instance总数：{instance_with_report_num=}")
print(f"保存的instance总数：{saved_num=}")

validated_instances = filter_dataset(args.hf_name, args.split, success_validation_data)

# Save locally. hf_name has the form "<org>/<name>", so the file lands in an
# "<org>/" sub-directory; the local filesystem writer does not create missing
# parent directories, so create it first.
output_path = f"{args.hf_name}.jsonl"
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
validated_instances.to_json(
    output_path,
    orient="records",
    lines=True,
    batch_size=10000,
    force_ascii=False
)
print(f"Wrote {len(validated_instances)} instances to {output_path}")


In [None]:
# Push validated instance images to Docker Hub
def push_validated_instances_images(
    instances,
    namespace: str = "lycfight",
    instance_image_tag: str = "latest"
):
    """Re-tag local SWE-bench instance images under ``namespace`` and push them.

    For each instance the local image
    ``sweb.eval.x86_64.<instance_id lower-cased>:<instance_image_tag>`` is
    tagged as ``<namespace>/sweb.eval.x86_64.<...>`` (with every ``"__"``
    replaced by ``"_s_"``) using the same tag, then pushed. A missing image or
    a push error is logged and counted; it does not stop the loop.

    Args:
        instances: Sized iterable of records with an ``'instance_id'`` key
            (e.g. the filtered dataset).
        namespace: Registry namespace/user to push to.
        instance_image_tag: Tag of the local images, reused for the pushed ones.

    Returns:
        Tuple ``(success, failed)`` with the number of pushed / failed images.
    """
    client = docker.from_env()
    success = 0
    failed = 0

    print(f"Pushing {len(instances)} validated instance images to {namespace}...")
    try:
        with tqdm(total=len(instances)) as pbar:
            for instance in instances:
                instance_id = instance['instance_id']
                local_repo = f"sweb.eval.x86_64.{instance_id.lower()}"
                image_name = f"{local_repo}:{instance_image_tag}"
                # NOTE(review): "__" is rewritten presumably because the target
                # registry rejects it in repository names — confirm.
                remote_repo = f"{namespace}/{local_repo}".replace("__", "_s_")

                try:
                    image = client.images.get(image_name)
                    # Pass repository and tag separately instead of a "repo:tag" string.
                    image.tag(remote_repo, tag=instance_image_tag)
                    for line in client.images.push(
                        remote_repo, tag=instance_image_tag, stream=True, decode=True
                    ):
                        # The push stream reports failures as a dict with an 'error' key.
                        if 'error' in line:
                            raise RuntimeError(line['error'])
                    success += 1
                except Exception as e:  # best effort: log and continue with the next image
                    # tqdm.write keeps the progress bar intact, unlike print().
                    tqdm.write(f"Error pushing {image_name}: {e}")
                    failed += 1

                pbar.update(1)
                pbar.set_postfix({
                    "success": success,
                    "failed": failed
                })
    finally:
        # Release the Docker API connection pool.
        client.close()

    print(f"Finished pushing images: {success} successful, {failed} failed")
    return success, failed

In [None]:
# Tag and upload the Docker image of every validated instance
# (default namespace and image tag).
push_validated_instances_images(instances=validated_instances)