mkdir dataset
cd dataset
git clone https://github.com/newfacade/LeetCodeDataset.git
mkdir LeetCodeDataset_postprocessed

# 获取训练的dataset

In [2]:
import os
import json
import gzip
from pathlib import Path
from typing import List

def read_jsonl(filename):
    """读取JSONL文件，每行解析为一个JSON对象"""
    filename = str(filename)  # 兼容 Path
    if filename.endswith('.gz'):
        with gzip.open(filename, 'rt', encoding='utf-8') as f:
            return [json.loads(line) for line in f]
    else:
        with open(filename, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f]

def write_jsonl(filename, data: List[dict]):
    """将一系列 JSON 对象写入到 .jsonl 文件中"""
    with open(filename, 'w', encoding='utf-8') as f:
        for obj in data:
            f.write(json.dumps(obj, ensure_ascii=False) + '\n')

def get_problem_file(version: str, split: str) -> Path:
    """获取指定版本和分割的问题文件路径"""
    dataset_path = Path('dataset/LeetCodeDataset/data')
    filename = f"LeetCodeDataset-{version}-{split}.jsonl.gz"
    return dataset_path / filename

def get_output_file(version: str, split: str) -> Path:
    """获取输出文件路径"""
    dataset_path = Path('dataset/LeetCodeDataset_postprocessed')
    filename = f"LeetCodeDataset-{version}-{split}.jsonl"
    return dataset_path / filename

def main():
    VERSION = "v0.3.1"
    SPLIT = "test"
    problem_file = get_problem_file(VERSION, SPLIT)
    output_file = get_output_file(VERSION, SPLIT)

    print(f"正在读取文件: {problem_file}")
    problems = read_jsonl(problem_file)
    print(f"总共有 {len(problems)} 道题目")

    print(f"正在写入文件: {output_file}")
    write_jsonl(output_file, problems)
    print("写入完成。")

if __name__ == "__main__":
    main()


正在读取文件: dataset/LeetCodeDataset/data/LeetCodeDataset-v0.3.1-test.jsonl.gz
总共有 228 道题目
正在写入文件: dataset/LeetCodeDataset_postprocessed/LeetCodeDataset-v0.3.1-test.jsonl
写入完成。


# 增加unittest字段 
-- 适配我们的测试docker

In [3]:
import os
import json
import gzip
import ast
import re
from pathlib import Path
from typing import List, Tuple

def read_jsonl(filename):
    filename = str(filename)
    if filename.endswith('.gz'):
        with gzip.open(filename, 'rt', encoding='utf-8') as f:
            return [json.loads(line) for line in f]
    else:
        with open(filename, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f]

def write_jsonl(filename, data: List[dict]):
    with open(filename, 'w', encoding='utf-8') as f:
        for obj in data:
            f.write(json.dumps(obj, ensure_ascii=False) + '\n')

def get_problem_file(version: str, split: str) -> Path:
    dataset_path = Path('dataset/LeetCodeDataset_postprocessed')
    filename = f"LeetCodeDataset-{version}-{split}.jsonl"
    return dataset_path / filename

def get_output_file(version: str, split: str) -> Path:
    dataset_path = Path('dataset/LeetCodeDataset_postprocessed')
    filename = f"LeetCodeDataset-{version}-{split}.jsonl"
    return dataset_path / filename

def split_top_level_assignments(input_str: str) -> List[str]:
    result = []
    current = []
    depth = 0
    in_str = False
    str_char = ''
    i = 0
    while i < len(input_str):
        c = input_str[i]
        if in_str:
            current.append(c)
            if c == str_char and input_str[i - 1] != '\\':
                in_str = False
            i += 1
            continue
        if c in ('\"', "'"):
            in_str = True
            str_char = c
            current.append(c)
            i += 1
            continue
        if c in '[{(':
            depth += 1
        elif c in ']})':
            depth -= 1
        if c == ',' and depth == 0 and not in_str:
            result.append(''.join(current).strip())
            current = []
            i += 1
            continue
        current.append(c)
        i += 1
    if current:
        result.append(''.join(current).strip())
    return result

def default_json_encoder(obj):
    if obj is ...:
        return "<ELLIPSIS>"
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def detect_output_type_from_tests(test_code: str) -> str:
    if "is_same_list" in test_code:
        return "list_node"
    elif "is_same_tree" in test_code:
        return "tree_node"
    return "general"

def generate_unittests(problem: dict, failed_cases: List[Tuple[str, str, str]]) -> List[dict]:
    if 'input_output' not in problem:
        return []

    unittests = []
    qid = str(problem.get('question_id') or problem.get('task_id') or "UNKNOWN")
    test_type = detect_output_type_from_tests(problem.get("test", ""))

    for io_pair in problem['input_output']:
        input_str = io_pair['input']
        expected_output = io_pair['output']
        input_dict = {}

        try:
            input_str = (
                input_str.replace("null", "None")
                         .replace("true", "True")
                         .replace("false", "False")
            )

            assign_lines = split_top_level_assignments(input_str)
            cleaned_lines = []
            for line in assign_lines:
                line = line.strip()
                if line.endswith('"') and line.count('"') % 2 != 0:
                    line = line[:-1]
                cleaned_lines.append(line)

            assign_block = '\n'.join(cleaned_lines)
            local_env = {}
            exec(assign_block, {}, local_env)
            input_dict = local_env

            input_json = json.dumps(input_dict, default=default_json_encoder) + "\n"
            unittest = {
                "input": input_json,
                "output": [{
                    "type": test_type,
                    "value": expected_output
                }]
            }
            unittests.append(unittest)

        except Exception as e:
            print(f"\n解析失败: question_id/task_id: {qid}")
            print(f"Input: {input_str}")
            print(f"错误: {e}")
            failed_cases.append((qid, input_str, str(e)))
            continue

    return unittests

def add_unittests_to_problems(problems: List[dict]) -> Tuple[List[dict], List[Tuple[str, str, str]]]:
    failed_cases = []
    for problem in problems:
        unittests = generate_unittests(problem, failed_cases)
        if unittests:
            problem['unittests'] = unittests
    return problems, failed_cases

def main():
    VERSION = "v0.3.1"
    SPLIT = "test"
    problem_file = get_problem_file(VERSION, SPLIT)
    output_file = get_output_file(VERSION, SPLIT)
    failed_log_file = output_file.with_name(output_file.stem + "_failed.jsonl")

    print(f"正在读取文件: {problem_file}")
    problems = read_jsonl(problem_file)
    print(f"总共有 {len(problems)} 道题目")

    print("正在添加unittests字段...")
    problems, failed_cases = add_unittests_to_problems(problems)

    print(f"正在写入文件: {output_file}")
    write_jsonl(output_file, problems)
    print("写入完成。")

    print("\n前两个样本的 unittests 字段：")
    for i, p in enumerate(problems[:2]):
        print(f"\n=== 样本 {i+1} ===")
        print(json.dumps(p.get("unittests", []), indent=2, ensure_ascii=False))

    print("\n========== 解析失败样本报告 ==========")
    print(f"共计失败样本数: {len(failed_cases)}")
    for i, (qid, input_str, err) in enumerate(failed_cases[:20]):
        print(f"\n[{i+1}] question_id: {qid}\nInput: {input_str}\nError: {err}")

    if failed_cases:
        failed_ids = {qid for qid, _, _ in failed_cases}
        print(f"\n总共有 {len(failed_ids)} 个不同的 task 出现了解析失败。")

        print(f"\n正在写入失败样本日志到: {failed_log_file}")
        write_jsonl(failed_log_file, [
            {"question_id": qid, "input": input_str, "error": err}
            for qid, input_str, err in failed_cases
        ])
        print("失败样本日志写入完成。")

        failed_problems = [p for p in problems if str(p.get("question_id") or p.get("task_id")) in failed_ids]
        failed_full_json_file = output_file.with_name(output_file.stem + "_failed_full.json")
        with open(failed_full_json_file, "w", encoding="utf-8") as f:
            json.dump(failed_problems, f, ensure_ascii=False, indent=2)
        print(f"完整失败样本已保存为 JSON: {failed_full_json_file}")

if __name__ == "__main__":
    main()


正在读取文件: dataset/LeetCodeDataset_postprocessed/LeetCodeDataset-v0.3.1-test.jsonl
总共有 228 道题目
正在添加unittests字段...

解析失败: question_id/task_id: 3257
Input: board
错误: name 'board' is not defined

解析失败: question_id/task_id: 3265
Input: nums = [101010, 101001, 110010, 100110, 100101, 111000, 000111]
错误: leading zeros in decimal integer literals are not permitted; use an 0o prefix for octal integers (<string>, line 1)

解析失败: question_id/task_id: 3267
Input: nums = [1010101, 0101010, 1001001, 0010010, 1100110, 0110110]
错误: leading zeros in decimal integer literals are not permitted; use an 0o prefix for octal integers (<string>, line 1)

解析失败: question_id/task_id: 3267
Input: nums = [101010, 110101, 010101, 100100]
错误: leading zeros in decimal integer literals are not permitted; use an 0o prefix for octal integers (<string>, line 1)

解析失败: question_id/task_id: 3267
Input: nums = [101010, 010101, 101001, 010110, 100110, 011010, 110010]
错误: leading zeros in decimal integer literals are not permitt

# 增加preCodeSegment字段
-- 为了接收测试用例的输入，然后输出tree/list/str格式的答案。

In [4]:
import json
import os
preCodeSegment = """
import json
import sys
import random
import functools
import collections
import string
import math
import datetime
from typing import *
from functools import *
from collections import *
from itertools import *
from heapq import *
from bisect import *
from string import *
from operator import *
from math import *

inf = float('inf')

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def list_node(values: list):
    if not values:
        return None
    head = ListNode(values[0])
    p = head
    for val in values[1:]:
        node = ListNode(val)
        p.next = node
        p = node
    return head

def print_list(head: ListNode) -> None:
    result = []
    while head:
        result.append(head.val)
        head = head.next
    print(result)
    
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def tree_node(values: list):
    if not values:
        return None
    root = TreeNode(values[0])
    i = 1
    queue = deque()
    queue.append(root)
    while queue:
        node = queue.popleft()
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root

def print_tree(root: TreeNode) -> None:
    if not root:
        print([])
        return

    result = []
    queue = deque([root])

    while queue:
        node = queue.popleft()
        if node:
            result.append(node.val)
            queue.append(node.left)
            queue.append(node.right)
        else:
            result.append(None)

    while result and result[-1] is None:
        result.pop()

    print(result)

"""
input_file = "dataset/LeetCodeDataset_postprocessed/LeetCodeDataset-v0.3.1-test.jsonl"
temp_output = input_file + ".tmp"

with open(input_file, "r", encoding="utf-8") as fin, open(temp_output, "w", encoding="utf-8") as fout:
    for line in fin:
        data = json.loads(line)
        data["preCodeSegment"] = preCodeSegment
        fout.write(json.dumps(data, ensure_ascii=False) + "\n")

# Replace the original file with the updated one
os.replace(temp_output, input_file)

# 增加postCodeSegment字段
-- 为了接收测试用例的输入，然后输出tree/list/str格式的答案。

In [5]:
import json
from collections import OrderedDict

# 数据路径
input_file = "dataset/LeetCodeDataset_postprocessed/LeetCodeDataset-v0.3.1-test.jsonl"

# 按顺序定义要保留和输出的字段
fields_order = [
    "task_id",
    "question_id",
    "difficulty",
    "tags",
    "estimated_date",
    "entry_point",
    "query",
    "response",
    "completion",
    "preCodeSegment",
    "postCodeSegment",
    "unittests",
    
]

def generate_post_code(entry_point: str, unittest_type: str) -> str:
    fixed_lines = [
        "if __name__ == '__main__':",
        "    data = json.loads(sys.stdin.read())",
        f"    result = {entry_point}(**data)"
    ]

    if unittest_type == "list_node":
        fixed_lines.append("    print_list(result)")
    elif unittest_type == "tree_node":
        fixed_lines.append("    print_tree(result)")
    else:  # default to general
        fixed_lines.append("    print(result)")

    return "\n".join(fixed_lines)

# 读取、处理、写入
filtered_data = []

with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        item = json.loads(line)

        # 获取unittest type（确保字段存在并且至少有一个用例）
        unittest_type = item.get("unittests", [{}])[0].get("output", [{}])[0].get("type", "---")
        # 生成 postCodeSegment
        entry_point = item.get("entry_point", "unknown_function")
        post_code = generate_post_code(entry_point, unittest_type)

        # 构建有序字典
        ordered_item = OrderedDict()
        for key in fields_order:
            if key == "postCodeSegment":
                ordered_item[key] = post_code
            elif key in item:
                ordered_item[key] = item[key]
            else:
                ordered_item[key] = None  # 可根据需求替换成 "" 或其他默认值

        filtered_data.append(ordered_item)

# 写回原文件（覆盖）
with open(input_file, 'w', encoding='utf-8') as f:
    for item in filtered_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

print("数据处理完成，postCodeSegment 已生成并字段顺序保留。")


数据处理完成，postCodeSegment 已生成并字段顺序保留。


# 把query字段名称换成prompt

In [1]:
import json

input_file = "dataset/LeetCodeDataset_postprocessed/LeetCodeDataset-v0.3.1-test.jsonl"
output_lines = []

with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        item = json.loads(line)
        # 如果有 query 字段，则重命名为 prompt
        if "query" in item:
            item["prompt"] = item.pop("query")
        # 删除 response 字段的内容
        if "response" in item:
            del item["response"]
        output_lines.append(json.dumps(item, ensure_ascii=False))

# 覆盖写回原文件
with open(input_file, 'w', encoding='utf-8') as f:
    for line in output_lines:
        f.write(line + '\n')

print("字段 query 已成功重命名为 prompt，response 字段已被删除。")

字段 query 已成功重命名为 prompt，response 字段已被删除。


# 测试

In [7]:

import requests
import json
from pathlib import Path
from collections import deque


def extract_code_from_completion(completion: str) -> str:
    """从 markdown 风格的字符串中提取代码块"""
    try:
        if not isinstance(completion, str):
            return ""
        code_blocks = re.findall(r"```(?:[a-zA-Z]*\n)?([\s\S]*?)```", completion)
        return "\n".join(code_blocks).strip() if code_blocks else ""
    except Exception:
        return ""

def main():
    question_id = 105
    jsonl_path = Path("dataset/LeetCodeDataset_postprocessed/LeetCodeDataset-v0.3.1-test.jsonl")

    raw_response = ""
    pre_code = ""
    post_code = ""
    unittests = []

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            obj = json.loads(line)
            if str(obj.get("question_id")) != str(question_id):
                continue

            raw_response = obj.get("response", "")
            pre_code = obj.get("preCodeSegment", "")
            post_code = obj.get("postCodeSegment", "")

            for ut in obj.get("unittests", []):
                try:
                    evaluated_input = json.loads(ut["input"])
                except Exception as e:
                    evaluated_input = f"<Invalid input: {e}>"
                unittests.append({
                    "input": evaluated_input,
                    "output": ut["output"]
                })
            break  # 找到目标样本，结束循环

    # 提取响应中的代码段
    response_code = extract_code_from_completion(raw_response)

    # 拼接完整测试代码
    full_test_code = "\n\n".join([pre_code, response_code, post_code])
    

    test_data = {
        "language": "Python 3",
        "source_code": full_test_code,
    #     "unittests": [
    #     {
    #         "input": json.dumps({"nums": [-100, -50, -25, -10, -3, 0, 5, 9, 20, 50, 75, 100]}) + "\n",
    #         "output": [
    #             {
    #                 "type": "tree_node",
    #                 "value": [0, -25, 20, -100, -10, 5, 75, None, -50, None, -3, None, 9, 50, 100]  # 使用列表表示树的结构
    #             }
    #         ]
    #     }
    # ],
        
        "unittests": unittests,
        
        "block_network": True,
        "stop_on_first_fail": False,
        "use_sanitizer": False
    }

    try:
        response = requests.post(
            "http://localhost:5000/api/execute_code",
            json=test_data,
            headers={"Content-Type": "application/json"}
        )

        if response.status_code == 200:
            result = response.json()
            print("测试结果:")
            print(json.dumps(result, indent=2, ensure_ascii=False))
            print(result)

        else:
            print(f"错误: HTTP {response.status_code}")
            print(response.text)

    except requests.exceptions.ConnectionError:
        print("连接错误: 无法连接到服务器，请确保服务正在运行")
    except Exception as e:
        print(f"发生错误: {str(e)}")

if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'requests'

In [8]:
import tracemalloc

def func_a():
    a = [i for i in range(10000)]

def func_b():
    b = {"key" + str(i): i for i in range(10000)}

tracemalloc.start()

func_a()
func_b()

snapshot = tracemalloc.take_snapshot()

# 按代码行统计内存分配
top_stats = snapshot.statistics('lineno')

print("Top 10 lines by memory usage:")
for stat in top_stats[:10]:
    print(stat)


Top 10 lines by memory usage:
/usr/lib/python3.10/codeop.py:118: size=256 B, count=3, average=85 B
/root/.local/lib/python3.10/site-packages/traitlets/traitlets.py:1514: size=168 B, count=1, average=168 B
/root/.local/lib/python3.10/site-packages/traitlets/traitlets.py:731: size=147 B, count=2, average=74 B
/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3519: size=112 B, count=1, average=112 B
/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3509: size=96 B, count=3, average=32 B
/usr/lib/python3.10/threading.py:313: size=88 B, count=2, average=44 B
/root/.local/lib/python3.10/site-packages/traitlets/traitlets.py:1543: size=72 B, count=1, average=72 B
/usr/lib/python3.10/threading.py:604: size=64 B, count=1, average=64 B
/root/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3568: size=64 B, count=1, average=64 B
/root/.local/lib/python3.10/site-packages/IPython/core/compilerop.py:192: size=28 B, count=1, average

算一下题目的最大token数量

In [1]:
import json
from transformers import AutoTokenizer

# 加载Qwen2.5-Coder-7B-Instruct的分词器
model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

max_token_count = 0
max_prompt = None

file_path = "dataset/LeetCodeDataset_postprocessed/LeetCodeDataset-v0.3.1-test.jsonl"

with open(file_path, 'r', encoding='utf-8') as f:
    for line in f:
        item = json.loads(line)
        prompt = item.get("prompt", "")
        token_count = len(tokenizer.encode(prompt))
        print(token_count)
        if token_count > max_token_count:
            max_token_count = token_count
            max_prompt = prompt

print(f"最大token数: {max_token_count}")
# 如需输出对应的prompt内容，可取消下一行注释
# print(f"对应的prompt: {max_prompt}")

  from .autonotebook import tqdm as notebook_tqdm


542
607
290
435
462
411
413
456
434
434
472
472
363
418
295
460
451
387
575
600
461
460
481
398
668
311
533
309
730
441
402
412
307
794
541
364
513
341
443
382
439
425
425
306
729
307
332
412
306
345
491
382
336
352
355
531
304
532
485
514
758
660
659
557
415
598
509
646
602
527
423
283
333
612
285
300
508
506
367
329
488
629
834
436
304
501
548
402
518
274
362
366
413
394
476
343
461
362
798
470
654
441
267
843
503
829
516
526
570
534
235
418
732
694
438
570
403
392
627
493
351
535
552
479
886
321
329
272
801
574
490
433
320
383
382
408
421
319
697
497
384
305
414
523
393
494
393
475
385
643
650
545
399
264
511
541
475
416
467
703
503
462
810
388
583
278
351
550
619
561
333
559
369
410
444
533
432
563
466
336
484
424
367
351
552
386
701
326
497
436
497
549
737
387
486
522
526
681
433
378
507
499
493
613
492
530
524
421
564
570
394
485
858
264
263
341
434
647
617
438
383
649
702
1307
528
372
374
581
最大token数: 1307
