In [None]:
import os
import json

# Parsed content of every JSON file in the input directory, one entry per file.
data = []

# Directory containing the collected outputs to analyse.
collected_output_dir = './collected_output_14B_NEW/tree_1'

# os.listdir() returns names in arbitrary, platform-dependent order. Sort them
# so `data` is built in the same order on every run; later cells draw seeded
# random samples while iterating over `data`, and those draws are only
# reproducible when the iteration order is stable.
json_files = sorted(f for f in os.listdir(collected_output_dir) if f.endswith('.json'))

for json_file in json_files:
    file_path = os.path.join(collected_output_dir, json_file)
    # Keep the file open only for as long as parsing needs it.
    with open(file_path, 'r', encoding='utf-8') as f:
        content = json.load(f)
    data.append(content)

print("读取到的 JSON 文件数量:", len(data))

In [None]:
# Walk the samples once and collect two totals:
#   positive_sum_count       -> samples holding at least one path with value_sum > 0
#   positive_value_sum_count -> paths with value_sum > 0 across those samples
positive_sum_count = 0
positive_value_sum_count = 0

for sample in data:
    n_positive = sum(path['value_sum'] > 0 for path in sample)
    if n_positive:
        positive_sum_count += 1
        positive_value_sum_count += n_positive

print(f"含有 value_sum > 0 的 sample 数量: {positive_sum_count}")
print(f"含有 value_sum > 0 的 sample 中的 value_sum > 0 的路径数量: {positive_value_sum_count}")

In [None]:
# Count the terminal paths whose Q value (value_sum / visit_count) is strictly
# positive. Paths with visit_count <= 0 are skipped before the division.
q_value_positive_terminal_count = sum(
    1
    for sample in data
    for path in sample
    if path['is_terminal']
    and path['visit_count'] > 0
    and path['value_sum'] / path['visit_count'] > 0
)
print(f"q value = value sum / visit count > 0 且 is_terminal = True 的路径数量: {q_value_positive_terminal_count}")

In [None]:
# Distribution of len(path['messages']) over the terminal paths whose Q value
# (value_sum / visit_count) is strictly positive.

q_value_positive_terminal_lengths = []
for sample in data:
    for path in sample:
        # Skip non-terminal paths and paths that cannot be divided safely.
        if not (path['is_terminal'] and path['visit_count'] > 0):
            continue
        if path['value_sum'] / path['visit_count'] > 0:
            q_value_positive_terminal_lengths.append(len(path['messages']))

# length -> number of occurrences
q_value_positive_terminal_length_counts = {}
for length in q_value_positive_terminal_lengths:
    seen_so_far = q_value_positive_terminal_length_counts.get(length, 0)
    q_value_positive_terminal_length_counts[length] = seen_so_far + 1

# Most frequent lengths first; only the top 10 are printed.
sorted_q_value_positive_terminal_length_counts = sorted(
    q_value_positive_terminal_length_counts.items(),
    key=lambda item: item[1],
    reverse=True,
)
print("q value = value sum / visit count > 0 且 is_terminal = True 的路径长度分布:")
for length, count in sorted_q_value_positive_terminal_length_counts[:10]:
    print(f"长度: {length}, 出现次数: {count}")

In [None]:
# Count the terminal paths whose Q value (value_sum / visit_count) is strictly
# negative. Paths with visit_count <= 0 are skipped before the division.
q_value_negative_terminal_count = 0
for sample in data:
    for path in sample:
        if path['is_terminal'] and path['visit_count'] > 0:
            if path['value_sum'] / path['visit_count'] < 0:
                q_value_negative_terminal_count += 1
print(f"q value = value sum / visit count < 0 且 is_terminal = True 的路径数量: {q_value_negative_terminal_count}")

In [None]:
import random

# For every sample, keep at most 4 "correct" terminal paths (Q value > 0) and
# at most 4 "incorrect" terminal paths (Q value < 0).
final_paths = []

for sample in data:
    correct_paths = []
    incorrect_paths = []

    for path in sample:
        if not path['is_terminal']:
            continue
        if not path['visit_count'] > 0:
            # Unvisited paths have no defined Q value and belong to neither bucket.
            continue

        q = path['value_sum'] / path['visit_count']
        if q > 0:
            correct_paths.append(path)
        elif q < 0:
            # Negative paths whose messages contain either of these markers are
            # left out of the incorrect bucket.
            transcript = str(path['messages'])
            if 'Code execution failed with status: Finished' not in transcript and 'TimeLimitExceeded' not in transcript:
                incorrect_paths.append(path)

    # The generator is re-seeded for every sample, so each sample's draws start
    # from the same generator state.
    random.seed(42)
    final_correct_paths = random.sample(correct_paths, min(4, len(correct_paths)))
    final_incorrect_paths = random.sample(incorrect_paths, min(4, len(incorrect_paths)))

    # With no correct path available, keep a single incorrect path only.
    # The list is known to be non-empty here, so drawing one element is safe.
    if not final_correct_paths and final_incorrect_paths:
        final_incorrect_paths = random.sample(final_incorrect_paths, 1)

    final_paths += final_correct_paths
    final_paths += final_incorrect_paths

print(f"最终抽取的路径数量: {len(final_paths)}")

In [None]:
# Tally the sampled paths by the sign of their Q value (value_sum / visit_count):
# positive -> correct, negative -> incorrect. Paths with visit_count <= 0 or a
# Q value of exactly 0 are counted in neither total.
correct_count = 0
incorrect_count = 0

for path in final_paths:
    if not path['visit_count'] > 0:
        continue
    q = path['value_sum'] / path['visit_count']
    if q > 0:
        correct_count += 1
    elif q < 0:
        incorrect_count += 1

print(f"正确路径数量: {correct_count}")
print(f"错误路径数量: {incorrect_count}")

In [None]:
# Length distribution (len(path['messages'])) of the sampled paths, split by the
# sign of the Q value: positive -> correct, negative -> incorrect.

correct_lengths = []
incorrect_lengths = []
for path in final_paths:
    if not path['visit_count'] > 0:
        continue
    q = path['value_sum'] / path['visit_count']
    if q > 0:
        correct_lengths.append(len(path['messages']))
    elif q < 0:
        incorrect_lengths.append(len(path['messages']))

# Frequency tables: length -> number of occurrences.
correct_length_counts = {}
for length in correct_lengths:
    correct_length_counts[length] = correct_length_counts.get(length, 0) + 1

incorrect_length_counts = {}
for length in incorrect_lengths:
    incorrect_length_counts[length] = incorrect_length_counts.get(length, 0) + 1

# Correct paths, most frequent length first.
sorted_correct_length_counts = sorted(
    correct_length_counts.items(), key=lambda item: item[1], reverse=True
)
print("正确路径长度及其出现次数:")
for length, count in sorted_correct_length_counts:
    print(f"长度: {length}, 出现次数: {count}")

# Incorrect paths, most frequent length first.
sorted_incorrect_length_counts = sorted(
    incorrect_length_counts.items(), key=lambda item: item[1], reverse=True
)
print("错误路径长度及其出现次数:")
for length, count in sorted_incorrect_length_counts:
    print(f"长度: {length}, 出现次数: {count}")


In [None]:
# Spot check: display the messages of the first sampled path.
# Raises IndexError when final_paths is empty.
final_paths[0]['messages']

In [None]:
import random

# Nodes selected for export: the sampled terminal paths plus every ancestor
# node on the way to them (ancestors are identified through dotted tag prefixes).
final_nodes = []

# Seeded once, before the loop: the draws for one sample continue from the
# generator state left by the previous sample.
random.seed(42)

for sample in data:
    correct_paths = []
    incorrect_paths = []

    # Tags of every node to keep for this sample.
    final_nodes_tags = set()

    for path in sample:
        if path['is_terminal']:
            # Unvisited paths get an average value of 0 and fall into neither bucket.
            avg_value = path['value_sum'] / path['visit_count'] if path['visit_count'] > 0 else 0
            if avg_value > 0:
                correct_paths.append(path)
            elif avg_value < 0 and 'Code execution failed with status: Finished' not in str(path['messages']) and 'TimeLimitExceeded' not in str(path['messages']):
                incorrect_paths.append(path)

    # At most 4 correct and 4 incorrect terminal paths per sample.
    final_correct_paths = random.sample(correct_paths, min(4, len(correct_paths)))
    final_incorrect_paths = random.sample(incorrect_paths, min(4, len(incorrect_paths)))

    # With no correct path available, keep at most 2 of the incorrect ones.
    if not final_correct_paths and final_incorrect_paths:
        final_incorrect_paths = random.sample(final_incorrect_paths, min(2, len(final_incorrect_paths)))

    # Record the tag of each selected terminal node ...
    for path in final_correct_paths + final_incorrect_paths:
        final_nodes_tags.add(path['tag'])
        # ... and of every proper prefix of that tag, e.g. 'a.b.c' -> 'a', 'a.b'.
        if path['tag']:
            parts = path['tag'].split('.')
            for i in range(1, len(parts)):
                intermediate_tag = '.'.join(parts[:i])
                final_nodes_tags.add(intermediate_tag)

    # tag -> node lookup; nodes with a falsy tag are left out. When a tag
    # occurs more than once in a sample, the last node with that tag wins.
    tag2path = {p['tag']: p for p in sample if p['tag']}

    # Iterate the dict (insertion order) rather than the set: iteration order
    # of a set of strings changes between interpreter runs because of hash
    # randomization, which made the order of final_nodes irreproducible even
    # with a fixed random seed. The set of selected nodes is unchanged.
    for tag, node in tag2path.items():
        if tag in final_nodes_tags:
            final_nodes.append(node)

print(f"最终抽取的节点数量: {len(final_nodes)}")

In [None]:
# Attach the Q value (value_sum / visit_count) to every selected node.
# final_nodes also holds intermediate nodes that were picked by tag and never
# checked for visit_count > 0, so guard the division; an unvisited node gets 0,
# the same convention the node-selection cell uses for its average value.
for node in final_nodes:
    visits = node['visit_count']
    node['q_value'] = node['value_sum'] / visits if visits > 0 else 0

In [None]:
# Spot check: display the first selected node.
# Raises IndexError when final_nodes is empty.
final_nodes[0]

In [None]:
# Inspect how the Q values of the selected nodes are distributed.
import collections

import matplotlib.pyplot as plt

q_values = []
for node in final_nodes:
    q_values.append(node['q_value'])

# The 15 most common distinct Q values and how often each occurs.
q_value_counter = collections.Counter(q_values)
for q_value, count in q_value_counter.most_common(15):
    print(f"Q Value: {q_value}, Count: {count}")

# Histogram of all Q values.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(q_values, bins=1000, alpha=0.75, color='blue', edgecolor='black')
ax.set_xlabel('Q Value')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Q Values')
ax.grid(axis='y', alpha=0.75)
plt.show()

In [None]:
import json

# Path of the JSON Lines file that receives the exported prompt/value records.
output_estimate_file = "estimate_round1_14B.jsonl"

def format_messages_simple(messages):
    """Render a list of chat messages as one plain-text string.

    Each message is written as its role followed by a colon, a line break and
    its content; consecutive messages are separated by one blank line.

    Args:
        messages: iterable of mappings that provide 'role' and 'content'.

    Returns:
        The joined text, or an empty string when there are no messages.
    """
    rendered = []
    for message in messages:
        rendered.append(f"{message['role']}:\n{message['content']}")
    return "\n\n".join(rendered)

# Build one {"prompt", "value"} record per selected node.
train_data = []
for path in final_nodes:
    messages = path['messages']
    prompt = format_messages_simple(messages)
    q_value = path['q_value']
    # Skip nodes whose rendered prompt is empty or whitespace only.
    if not prompt.strip():
        continue
    train_data.append({"prompt": prompt, "value": q_value})

# Write the records as JSON Lines; ensure_ascii=False keeps non-ASCII text readable.
with open(output_estimate_file, "w", encoding="utf8") as fout:
    for d in train_data:
        fout.write(json.dumps(d, ensure_ascii=False) + "\n")

# Read the first record back as a sanity check. It is bound to its own name
# rather than `data`, which holds the loaded samples that the earlier cells
# iterate over; rebinding `data` here would break re-running those cells.
with open(output_estimate_file, "r", encoding="utf8") as fin:
    first_line = fin.readline()

# An empty file (no records written) yields an empty string: nothing to show.
if first_line:
    first_record = json.loads(first_line)
    print(first_record['prompt'], first_record['value'])