In [10]:
# Merge every train.jsonl found under /Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/PKU-SafeRLHF/data into a single combined_train_{num}.json (num = number of records kept); do the same for test.jsonl.
import os
import json
from glob import glob

# Root directory that is searched recursively for the per-split JSONL files;
# the combined JSON outputs are written into this same directory.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
data_dir = "/Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/PKU-SafeRLHF/data"

def is_valid_record(record):
    """Return whether a record's safety labels agree with its safer-response id.

    A record is rejected when exactly one of its two responses is flagged safe
    (``is_response_0_safe`` / ``is_response_1_safe``) but ``safer_response_id``
    does not name that safe response. The cross-check only runs when all three
    fields are present (not ``None``); otherwise the record is kept.

    Any exception raised while inspecting the record (for example, ``record``
    not offering ``.get``) rejects it, which is the conservative choice.
    """
    try:
        flag_0 = record.get('is_response_0_safe')
        flag_1 = record.get('is_response_1_safe')
        preferred = record.get('safer_response_id')

        # With any of the three fields missing there is nothing to cross-check.
        if flag_0 is None or flag_1 is None or preferred is None:
            return True

        # Exactly one response is safe: the safer id has to point at it.
        if flag_0 is True and flag_1 is False:
            return preferred == 0
        if flag_1 is True and flag_0 is False:
            return preferred == 1

        # Both safe, both unsafe, or non-boolean flags: keep the record.
        return True
    except Exception:
        # Anything unexpected means the record cannot be trusted; drop it.
        return False

def combine_jsonl_files(data_dir, file_type):
    """Merge every ``{file_type}.jsonl`` under ``data_dir`` into one JSON file.

    Matching files are discovered recursively and read in sorted path order so
    that the record order of the output is reproducible (``glob`` does not
    guarantee any ordering). Each non-blank line is parsed as JSON; malformed
    lines are skipped and records rejected by ``is_valid_record`` are dropped.
    The surviving records are written as a JSON list to
    ``combined_{file_type}_{count}.json`` inside ``data_dir``.

    Args:
        data_dir: Root directory to search; also the output directory.
        file_type: Base name of the split to combine, e.g. ``'train'`` or
            ``'test'``.

    Returns:
        None. A summary is printed; a second line is printed only when some
        lines or records were dropped.
    """
    pattern = os.path.join(data_dir, f"**/{file_type}.jsonl")
    # Sort so the combined file does not depend on filesystem listing order.
    file_paths = sorted(glob(pattern, recursive=True))
    combined_data = []
    malformed_count = 0
    rejected_count = 0
    for fp in file_paths:
        with open(fp, "r", encoding="utf-8") as fin:
            for line in fin:
                if not line.strip():
                    continue  # blank lines carry no record
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    malformed_count += 1  # skip malformed lines, but keep count
                    continue
                if is_valid_record(rec):
                    combined_data.append(rec)
                else:
                    rejected_count += 1
    out_path = os.path.join(data_dir, f"combined_{file_type}_{len(combined_data)}.json")
    with open(out_path, "w", encoding="utf-8") as fout:
        json.dump(combined_data, fout, ensure_ascii=False, indent=2)
    print(f"Written {file_type} data to {out_path} (total {len(combined_data)} records)")
    if malformed_count or rejected_count:
        # Make dropped data visible instead of discarding it silently.
        print(
            f"Skipped {malformed_count} malformed lines and "
            f"{rejected_count} filtered records while combining {file_type} data"
        )

# Build one combined file per dataset split.
for split_name in ("train", "test"):
    combine_jsonl_files(data_dir, split_name)


Written train data to /Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/PKU-SafeRLHF/data/combined_train_73907.json (total 73907 records)
Written test data to /Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/PKU-SafeRLHF/data/combined_test_8211.json (total 8211 records)


In [11]:
import json
import os

# Load the combined training records (a JSON list of preference records).
# NOTE(review): the combining cell reports writing combined_train_73907.json
# under .../raw_data/PKU-SafeRLHF/data/, whereas this cell reads it from
# .../raw_data/. Confirm the file has been copied there before a fresh run.
json_path = "/Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/combined_train_73907.json"


def _split_safe_unsafe(records):
    """Split preference records into safe and unsafe prompt/response pairs.

    For each record:
      * safe: the response named by ``safer_response_id`` is kept, but only
        if that response is itself flagged safe;
      * unsafe: the first response (index 0, then 1) that is explicitly
        flagged not safe is kept. A response whose safety flag is missing is
        never treated as unsafe, so unlabeled data cannot leak into the
        unsafe set.

    Args:
        records: Iterable of dicts read via ``.get`` for the keys ``prompt``,
            ``response_0``, ``response_1``, ``is_response_0_safe``,
            ``is_response_1_safe`` and ``safer_response_id``.

    Returns:
        ``(safe_pairs, unsafe_pairs)``: two lists of
        ``{'prompt': ..., 'response': ...}`` dicts.
    """
    safe_pairs = []
    unsafe_pairs = []
    for item in records:
        prompt = item.get('prompt', '')
        safer_response_id = item.get('safer_response_id')
        responses = (item.get('response_0', ''), item.get('response_1', ''))
        # None marks a missing label; it is neither safe nor unsafe.
        safe_flags = (item.get('is_response_0_safe'), item.get('is_response_1_safe'))

        # 1. Safe data: the preferred response, provided it is flagged safe.
        for idx in (0, 1):
            if safer_response_id == idx and safe_flags[idx]:
                safe_pairs.append({'prompt': prompt, 'response': responses[idx]})
                break

        # 2. Unsafe data: at most one response per record, the first one that
        #    is explicitly labeled as not safe.
        for idx in (0, 1):
            flag = safe_flags[idx]
            if flag is not None and not flag:
                unsafe_pairs.append({'prompt': prompt, 'response': responses[idx]})
                break
    return safe_pairs, unsafe_pairs


with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

safe_qa, unsafe_qa = _split_safe_unsafe(data)

# Save the results; each file name carries its record count.
output_dir = "/Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data"
safe_output_path = os.path.join(output_dir, f"safe_qa_{len(safe_qa)}.json")
unsafe_output_path = os.path.join(output_dir, f"unsafe_qa_{len(unsafe_qa)}.json")

with open(safe_output_path, "w", encoding="utf-8") as f:
    json.dump(safe_qa, f, ensure_ascii=False, indent=2)

with open(unsafe_output_path, "w", encoding="utf-8") as f:
    json.dump(unsafe_qa, f, ensure_ascii=False, indent=2)

print(f"已保存 {len(safe_qa)} 条安全数据到 {safe_output_path}")
print(f"已保存 {len(unsafe_qa)} 条非安全数据到 {unsafe_output_path}")

已保存 41251 条安全数据到 /Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/safe_qa_41251.json
已保存 43469 条非安全数据到 /Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/unsafe_qa_43469.json


分离完成：134185 条 safe_prompt，0 条 unsafe_prompt
safe_prompt 数据已保存至 safeqa_safe_prompt_134185.json
unsafe_prompt 数据已保存至 safeqa_unsafe_prompt_0.json


In [12]:
import json
import os

# Goal: from safe_qa_41251.json take up to 50000 entries with pairwise-distinct
# prompts and save them as safe_qa_41251_sample_50k.json; then take up to 2000
# entries from unsafe_qa_43469.json whose prompts do not occur in that safe
# sample.

# Load the safe QA file.
safe_qa_path = "/Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/safe_qa_41251.json"
with open(safe_qa_path, "r", encoding="utf-8") as f:
    safe_qa_data = json.load(f)

# Deduplicate by prompt, keeping the first entry seen for each prompt.
# Entries with an empty or missing prompt are dropped.
seen_prompts = set()
unique_safe_qa = []
for item in safe_qa_data:
    prompt = item.get('prompt', '')
    if prompt and prompt not in seen_prompts:
        seen_prompts.add(prompt)
        unique_safe_qa.append(item)

# Keep at most 50000 entries; if fewer unique prompts exist, all are kept.
# NOTE(review): this is a head slice in file order, not a random draw.
sample_size = min(50000, len(unique_safe_qa))
safe_qa_sample = unique_safe_qa[:sample_size]

# Save the safe QA sample.
output_dir = "/Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data"
safe_sample_path = os.path.join(output_dir, "safe_qa_41251_sample_50k.json")
with open(safe_sample_path, "w", encoding="utf-8") as f:
    json.dump(safe_qa_sample, f, ensure_ascii=False, indent=2)

print(f"Safe QA: 原始数据 {len(safe_qa_data)} 条")
print(f"Safe QA: 去重后 {len(unique_safe_qa)} 条")
print(f"Safe QA: 已保存 {len(safe_qa_sample)} 条到 {safe_sample_path}")

# Load the unsafe QA file.
unsafe_qa_path = "/Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/unsafe_qa_43469.json"
with open(unsafe_qa_path, "r", encoding="utf-8") as f:
    unsafe_qa_data = json.load(f)

# Drop unsafe entries whose prompt also appears in the safe sample (and any
# entry with an empty or missing prompt).
# NOTE(review): unsafe entries are not deduplicated among themselves, so the
# unsafe sample may repeat a prompt — confirm that this is intended.
safe_prompts = {item['prompt'] for item in safe_qa_sample}
filtered_unsafe_qa = []
for item in unsafe_qa_data:
    prompt = item.get('prompt', '')
    if prompt and prompt not in safe_prompts:
        filtered_unsafe_qa.append(item)

# Keep the first 2000 entries (fewer if not enough remain after filtering).
unsafe_sample_size = min(2000, len(filtered_unsafe_qa))
unsafe_qa_sample = filtered_unsafe_qa[:unsafe_sample_size]

# Save the unsafe QA sample.
unsafe_sample_path = os.path.join(output_dir, f"unsafe_qa_sample_{unsafe_sample_size}.json")
with open(unsafe_sample_path, "w", encoding="utf-8") as f:
    json.dump(unsafe_qa_sample, f, ensure_ascii=False, indent=2)

print(f"\nUnsafe QA: 原始数据 {len(unsafe_qa_data)} 条")
print(f"Unsafe QA: 过滤重复 prompt 后 {len(filtered_unsafe_qa)} 条")
print(f"Unsafe QA: 已保存 {len(unsafe_qa_sample)} 条到 {unsafe_sample_path}")

# Verify that the two samples share no prompt. The safe prompt set built for
# filtering is reused rather than rebuilt.
safe_sample_prompts = safe_prompts
unsafe_sample_prompts = {item['prompt'] for item in unsafe_qa_sample}
overlap = safe_sample_prompts & unsafe_sample_prompts
print(f"\n验证: Safe 和 Unsafe 样本之间重复的 prompt 数量: {len(overlap)}")
# Enforce the invariant instead of relying on someone reading the printout.
assert not overlap, "safe and unsafe samples must not share prompts"


Safe QA: 原始数据 41251 条
Safe QA: 去重后 25283 条
Safe QA: 已保存 25283 条到 /Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/safe_qa_41251_sample_50k.json

Unsafe QA: 原始数据 43469 条
Unsafe QA: 过滤重复 prompt 后 25614 条
Unsafe QA: 已保存 2000 条到 /Users/boyuan/local_projects/wx_workspace/nature-resist/codespace/align-anything-resist/resist-collapse/dataset/raw_data/unsafe_qa_sample_2000.json

验证: Safe 和 Unsafe 样本之间重复的 prompt 数量: 0
