### Imports

In [None]:
import re
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from litellm import encode
from functools import partial
from multiprocessing import Pool
from datasets import load_dataset
from pandarallel import pandarallel
from swebench.harness.constants import TestStatus
from concurrent.futures import ThreadPoolExecutor
# Enable DataFrame/Series.parallel_apply (8 workers, no progress bar) and
# tqdm's progress_apply, both used throughout this notebook.
pandarallel.initialize(progress_bar=False, nb_workers=8)
tqdm.pandas()

### Helper Functions

In [None]:
def parse_status(content):
    """Classify an evaluation log by its summary line.

    Returns 0 when the log contains 'Some tests failed.', otherwise 1 when it
    contains 'All tests passed.', and -1 when neither marker is present. The
    failure marker is checked first, so it wins if both appear.
    """
    for marker, status in (('Some tests failed.', 0), ('All tests passed.', 1)):
        if marker in content:
            return status
    return -1

def get_token_count(text):
    """Return the number of tokens in ``text`` under litellm's gpt-4o tokenizer."""
    token_ids = encode(model="gpt-4o", text=text)
    return len(token_ids)

def extract_code_blocks(text: str) -> list[str]:
    """Return the bodies of all triple-backtick fenced blocks in ``text``.

    An optional word-character language tag right after the opening fence is
    skipped. Each captured body starts after the first newline and runs
    (non-greedily, across lines) up to the next fence.
    """
    fence_re = re.compile(r'```(?:\w+)?\n(.*?)```', re.DOTALL)
    return fence_re.findall(text)

def get_create_files(traj):
    """Return the paths of files the agent created with the ``create`` command.

    ``traj`` is a list of events, each a dict with a ``"text"`` entry (may be
    None). For an event whose first fenced code block starts with ``create``,
    the created path is read from the ``(Open file: <path>)`` footer of the
    *next* event's text.

    Returns:
        Paths in trajectory order, exactly as printed in the footer.
    """
    create_files = []
    for idx, event in enumerate(traj):
        if event["text"] is None:
            continue
        action = extract_code_blocks(event["text"])
        if len(action) == 0:
            continue
        action = action[0]
        if action.split(" ")[0] == "create":
            if idx + 1 >= len(traj):
                continue
            observation = traj[idx + 1]['text']
            # A None observation used to raise AttributeError, and a text
            # without the footer yielded a garbage "path" (its prefix up to
            # the first ')').
            if not observation or '(Open file: ' not in observation:
                continue
            create_file = observation.split('(Open file: ')[-1].split(')')[0]
            create_files.append(create_file)
    return create_files

def get_modified_files(traj):
    """Return the paths of files the agent modified via ``sed``, ``edit`` or ``echo ... >>``.

    For each event, the first fenced code block is taken as the action:

    * ``sed`` / ``echo`` with ``>>``: the path is the last space-separated
      token of the action, truncated at its first newline.
    * ``edit``: the path is read from the ``(Open file: <path>)`` footer of
      the next event's text.

    Returns:
        Paths in trajectory order (duplicates possible). Paths may be
        absolute or relative, as typed or printed by the agent.
    """
    modified_files = []
    for idx, event in enumerate(traj):
        if event["text"] is None:
            continue
        action = extract_code_blocks(event["text"])
        if len(action) == 0:
            continue
        action = action[0]
        tokens = action.split()
        if not tokens:
            # An empty or whitespace-only fenced block has no command;
            # indexing tokens[0] here used to raise IndexError.
            continue
        command = tokens[0]
        if command == "sed":
            modified_file = action.split(" ")[-1].split('\n')[0]
            modified_files.append(modified_file)
        elif command == "edit":
            if idx + 1 >= len(traj):
                continue
            observation = traj[idx + 1]['text']
            # Guard against a None observation or one without the footer.
            if not observation or '(Open file: ' not in observation:
                continue
            modified_file = observation.split('(Open file: ')[-1].split(')')[0]
            modified_files.append(modified_file)
        elif command == "echo":
            # Only appending redirections modify a file.
            if ">>" not in action:
                continue
            modified_file = action.split(" ")[-1].split('\n')[0]
            modified_files.append(modified_file)
    return modified_files

def get_files_from_patch(patch_content: str) -> list[str]:
    """List every path named in the ``diff --git a/<x> b/<y>`` headers of a patch.

    Both the ``a/`` and ``b/`` sides are collected, with prefixes removed.
    Each path appears once, in order of first appearance.
    """
    seen = {}  # insertion-ordered set
    for line in patch_content.split('\n'):
        if not line.startswith('diff --git'):
            continue
        tokens = line.split()
        old_path = tokens[2][2:]  # drop 'a/'
        new_path = tokens[3][2:]  # drop 'b/'
        seen.setdefault(old_path, None)
        seen.setdefault(new_path, None)
    return list(seen)

def get_new_files_from_patch(patch_content: str) -> list[str]:
    """Get list of newly created files from git patch.

    A file counts as new when a ``new file mode`` line follows its
    ``diff --git a/<x> b/<y>`` header; the path is the header's ``b/`` side
    with the prefix removed.

    The most recent ``diff --git`` header is tracked explicitly. The previous
    ``lines[i - 1]`` lookup wrapped around to the patch's last line when
    ``new file mode`` was the first line, which produced a garbage path.
    """
    new_files = []
    current_header = None
    for line in patch_content.splitlines():
        if line.startswith('diff --git'):
            current_header = line
        elif line.startswith('new file mode') and current_header is not None:
            new_files.append(current_header.split()[-1][2:])  # Remove b/ prefix
    return new_files

def get_relative_path(full_path):
    """Map a trajectory file path to a repo-relative path comparable with patch paths.

    Absolute paths are assumed to look like ``/<repo_name>/<rest>`` (the form
    shown in SWE-agent's "(Open file: ...)" footer) and lose their first two
    components. Relative paths keep all their components, apart from a leading
    ``./``. Presumably they are already relative to the repo root (the agent's
    working directory); verify against the trajectories.

    Only absolute paths are stripped. Stripping relative ones too would turn
    ``src/pkg/mod.py`` into ``mod.py``, which then never matches the patch.
    """
    if full_path.startswith('/'):
        parts = full_path.split('/', 2)  # ['', repo_name, rest]
        if len(parts) >= 3:
            return parts[2]  # everything after repo_name
        return full_path.lstrip('/')  # e.g. '/setup.py' -> 'setup.py'
    if full_path.startswith('./'):
        return full_path[2:]
    return full_path

def filter_allowed_changes(patch_content, allowed_files):
    """Keep only the per-file sections of a git patch whose target file is allowed.

    Args:
        patch_content: A multi-file ``git diff`` text.
        allowed_files: Paths (absolute or relative) passed through
            ``get_relative_path`` before being compared against the ``b/``
            path of each ``diff --git`` header.

    Returns:
        The concatenated allowed sections, each ending in a newline, or ''
        when no section survives.
    """
    relative_allowed_files = {get_relative_path(path) for path in allowed_files}
    header = 'diff --git '
    sections = patch_content.split('\n' + header)

    if sections[0].startswith(header):
        sections[0] = sections[0][len(header):]

    filtered_sections = []
    for section in sections:
        if not section.strip():
            continue
        section = header + section

        try:
            file_path = section.split(' b/', 1)[1].split('\n', 1)[0]
        except IndexError:
            # Not a well-formed diff header (e.g. text before the first diff).
            continue
        file_path = file_path.strip()
        if file_path.startswith('a/'):
            file_path = file_path[2:]

        if file_path in relative_allowed_files:
            # The split above consumed the newline that ended every section
            # except the last. Restore it; otherwise joining fuses the previous
            # section's last line with the next 'diff --git' header.
            if not section.endswith('\n'):
                section += '\n'
            filtered_sections.append(section)

    result = ''.join(filtered_sections)
    if not result.strip():
        return ''
    return result

def get_deleted_files(patch_content):
    """Return the repo-relative paths of files a git patch deletes.

    A file counts as deleted when a ``deleted file mode`` line appears in its
    diff section before the ``--- a/<path>`` line. The path is taken from
    that ``---`` line.
    """
    deleted = []
    pending_delete = False
    for line in patch_content.splitlines():
        if line.startswith('diff --git'):
            pending_delete = False  # new file section
        elif line.startswith('deleted file mode'):
            pending_delete = True
        elif line.startswith('--- a/'):
            path = line[len('--- a/'):]
            if pending_delete and path:
                deleted.append(path)
    return deleted

def if_remove_file(file_path: str, deleted_files: set[str], created_files: set[str]) -> bool:
    """Decide whether ``file_path`` should be dropped from a cleaned patch.

    A path is dropped if it contains "htm" and is in ``deleted_files`` or
    ``created_files``, or if any exclusion regex matches anywhere in it
    (``re.search``). Note the patterns are substring matches: 'test', 'fix',
    'debug', 'dist' and the like hit any path containing those words.
    """
    exclude_patterns = [
        r'reproduc', r'debug', r'fix', r'test',
        r'\.pyc$', r'\.cow$', r'\.zip$', r'\.whl$', r'\.dvc$',
        r'\.deb$', r'\.dcm$', r'\.gz$', r'\.tar$',
        r'\.txt$', r'\.md$', r'\.rst$',
        r'\.\d+', r'gitignore',
        r'venv/bin/', r'eggs', r'dist', r'egg-info'
    ]
    is_html_like = "htm" in file_path
    if is_html_like and (file_path in deleted_files or file_path in created_files):
        return True
    for pattern in exclude_patterns:
        if re.search(pattern, file_path):
            return True
    return False

def clean_patch(patch, modified_files, created_files):
    """Reduce a generated patch to the files the agent meaningfully modified.

    A file's diff section is kept only if the file is in ``modified_files``
    and is not created, deleted, or flagged by ``if_remove_file`` (test/repro
    scripts, build artifacts, ...).

    ``modified_files`` / ``created_files`` come from the trajectory and may be
    absolute (``/<repo>/...``, from the "(Open file: ...)" footer). The paths
    parsed from the patch are repo-relative. All comparisons therefore happen
    after ``get_relative_path``. Subtracting relative paths from absolute ones
    removes nothing.

    Returns:
        The filtered patch text, or '' if nothing remains.
    """
    all_files = get_files_from_patch(patch)
    deleted_files = set(get_deleted_files(patch))
    created_rel = {get_relative_path(path) for path in created_files}
    to_remove_files = {f for f in all_files if if_remove_file(f, deleted_files, created_rel)}
    excluded = created_rel | to_remove_files | deleted_files
    # Keep the original path strings; filter_allowed_changes normalizes them itself.
    retain_files = [path for path in set(modified_files) if get_relative_path(path) not in excluded]
    return filter_allowed_changes(patch, retain_files)

def parse_log_pytest(log: str) -> dict[str, str]:
    """Map test names to statuses from pytest summary lines in ``log``.

    A line qualifies when it starts with one of the ``TestStatus`` values.
    For FAILED lines, " - " separators are collapsed first. The first
    whitespace token is the status and the second is the test name. Lines
    with fewer than two tokens are ignored, and later lines overwrite
    earlier ones for the same test.
    """
    status_prefixes = tuple(status.value for status in TestStatus)
    failed_prefix = TestStatus.FAILED.value
    statuses = {}
    for raw_line in log.split("\n"):
        if not raw_line.startswith(status_prefixes):
            continue
        line = raw_line.replace(" - ", " ") if raw_line.startswith(failed_prefix) else raw_line
        fields = line.split()
        if len(fields) > 1:
            statuses[fields[1]] = fields[0]
    return statuses

### Dataframe creation and cleaning

In [None]:
# Load SWE-agent trajectories and keep only rows whose eval log has a definite
# outcome (parse_status: 'Some tests failed.' -> 0, 'All tests passed.' -> 1).
data = load_dataset('nebius/SWE-agent-trajectories', split='train')
df = pd.DataFrame(data)
df.fillna('', inplace=True)
df['status_msg'] = df['eval_logs'].apply(parse_status)
df = df[df['status_msg'] != -1].reset_index(drop=True)
# Recover from the trajectory which files the agent created / modified, then
# strip everything else (created files, flagged test/repro files, artifacts)
# out of the generated patch.
df['created_file'] = df['trajectory'].progress_apply(get_create_files)
df['modified_file'] = df['trajectory'].progress_apply(get_modified_files)
df['generated_patch'] = df.parallel_apply(lambda x: clean_patch(x['generated_patch'], x['modified_file'], x['created_file']), axis=1)
# Dedupe cleaned patches, then drop rows where cleaning removed everything.
df.drop_duplicates(subset=['generated_patch'], inplace=True, ignore_index=True)
df = df[df['generated_patch'] != ''].reset_index(drop=True)
# gpt-4o token counts; only patches of at most 4000 tokens are kept.
df['token_count'] = df['generated_patch'].parallel_apply(get_token_count)
df['eval_token_count'] = df['eval_logs'].parallel_apply(get_token_count)
df = df[df['token_count'] <= 4000].reset_index(drop=True)
# Per-test pytest statuses from the eval log; rows with none parsed are dropped.
df['test_status'] = df['eval_logs'].parallel_apply(parse_log_pytest)
df = df[df['test_status'] != {}].reset_index(drop=True)
# Snapshot with all columns, then keep only what the reward-data steps use.
df.to_csv('reward_data/initial_reward_model_train.csv', index=False)
df = df[['instance_id', 'generated_patch', 'test_status', 'target', 'token_count']]

### Reward Data Prep

In [None]:
def evaluate(x):
    """Best-effort parse of a serialized Python value (e.g. a list/dict repr).

    Returns:
        The evaluated object, or ``x`` unchanged when it cannot be evaluated
        (non-string input such as NaN, plain text, or a malformed repr).
    """
    try:
        # NOTE(security): eval() executes arbitrary code. Inputs here include
        # Hugging Face dataset fields (FAIL_TO_PASS / PASS_TO_PASS) and CSVs
        # holding model-generated text. ast.literal_eval would be the safe
        # choice if every value is a plain literal; confirm before switching.
        return eval(x)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # still propagate during long progress_apply runs.
        return x

def get_tests(test_status, tests):
    """Return the tests from ``tests`` that did not pass, in their original order.

    This follows SWE-bench's grading rule. A test counts as failed when its
    status is FAILED or ERROR, or when it is absent from ``test_status``
    (it never ran, or its result line was not parsed). The previous
    ``test_status[test]`` lookup raised KeyError for absent tests.

    Args:
        test_status: Mapping test name -> status string (from parse_log_pytest).
        tests: Iterable of test names to check.
    """
    failed_statuses = ('FAILED', 'ERROR')
    return [test for test in tests if test_status.get(test, 'FAILED') in failed_statuses]

def get_bugfixing_status(f2p_failed, f2p):
    """Score how many FAIL_TO_PASS tests a patch fixes.

    Returns 2 when none of them still fail, 0 when all of them still fail,
    and 1 for anything in between.
    """
    still_failing = len(f2p_failed)
    if not still_failing:
        return 2
    return 0 if still_failing == len(f2p) else 1

def new_issues_status(p2p_failed, p2p):
    """Score regressions among the PASS_TO_PASS tests.

    Returns 2 when none of them fail, 0 when all of them fail, and 1 for
    anything in between.
    """
    broken = len(p2p_failed)
    if not broken:
        return 2
    return 0 if broken == len(p2p) else 1

In [None]:
# Attach reference data (gold patch, issue text, FAIL_TO_PASS / PASS_TO_PASS
# test lists) from SWE-bench-extra plus the SWE-bench dev split.
data = load_dataset('nebius/SWE-bench-extra', split='train')
df_bench = pd.DataFrame(data)[['instance_id', 'patch', 'problem_statement', 'FAIL_TO_PASS', 'PASS_TO_PASS']]
data_swe = load_dataset('princeton-nlp/SWE-bench', split='dev')
df_swe = pd.DataFrame(data_swe)[['instance_id', 'patch', 'problem_statement', 'FAIL_TO_PASS', 'PASS_TO_PASS']]
df_bench = pd.concat([df_bench, df_swe], ignore_index=True)
# NOTE(review): if an instance_id occurs in both sources, this left merge
# duplicates the matching trajectory rows; confirm the two sets are disjoint.
df = df.merge(df_bench, on='instance_id', how='left')
# The test lists arrive serialized; evaluate() parses them. Rows without a
# match keep NaN, which fillna below turns into ''.
df['FAIL_TO_PASS'] = df['FAIL_TO_PASS'].progress_apply(evaluate)
df['PASS_TO_PASS'] = df['PASS_TO_PASS'].progress_apply(evaluate)
df.fillna('', inplace=True)
# Reference tests that get_tests() reports as failed for this generated patch.
df['P2P_failed'] = df.progress_apply(lambda x: get_tests(x['test_status'], x['PASS_TO_PASS']), axis=1)
df['F2P_failed'] = df.progress_apply(lambda x: get_tests(x['test_status'], x['FAIL_TO_PASS']), axis=1)
# 0/1/2 scores: all / some / none of the respective tests failed.
df['bug_fixing'] = df.apply(lambda x: get_bugfixing_status(x['F2P_failed'], x['FAIL_TO_PASS']), axis=1)
df['new_issues'] = df.apply(lambda x: new_issues_status(x['P2P_failed'], x['PASS_TO_PASS']), axis=1)
df.rename(columns={'P2P_failed': 'p2p_failed', 'F2P_failed': 'f2p_failed'}, inplace=True)
df.to_csv('reward_data/final_reward_model_train_before_aug.csv', index=False)

#### Critique Generation

After cleaning the patches, the next step is to generate a critique for each patch.  
For instructions, consult the accompanying `readme.md` file (its location is not specified in this notebook).



In [None]:
# Reload the reward data after critique generation (this CSV is presumably
# produced by the external step described above). Columns that to_csv
# serialized as Python reprs are parsed back with evaluate().
df = pd.read_csv('reward_data/after_critique_generation_gt.csv')
df.fillna('', inplace=True)
df['test_status'] = df['test_status'].progress_apply(evaluate)
df['p2p_failed'] = df['p2p_failed'].progress_apply(evaluate)
df['f2p_failed'] = df['f2p_failed'].progress_apply(evaluate)

In [None]:
# Drop rows labelled incorrect (target == 0) whose generated patch is
# identical to the reference patch.
df = df[~((df['generated_patch'] == df['patch']) & (df['target'] == 0))]

In [None]:
import pandas as pd
from itertools import combinations
from collections import defaultdict 
from concurrent.futures import ProcessPoolExecutor
import random

def get_patch_info(row, if_gt=False):
    """Collect the per-patch fields used to build training rows.

    With ``if_gt`` set, the reference patch is described instead. It is always
    marked correct (target True, no failing tests, both scores 2) and uses the
    ``critique_gt`` column for its critique.
    """
    if if_gt:
        return {
            'patch': row['patch'],
            'target': True,
            'p2p_failed': [],
            'f2p_failed': [],
            'bug_fixing': 2,
            'new_issues': 2,
            'critique': row['critique_gt'],
        }
    return {
        'patch': row['generated_patch'],
        'target': row['target'],
        'p2p_failed': row['p2p_failed'],
        'f2p_failed': row['f2p_failed'],
        'bug_fixing': row['bug_fixing'],
        'new_issues': row['new_issues'],
        'critique': row['critique'],
    }

def sample_incorrect_patches(failed_patches_dict, num_patches_needed, max_samples=3):
    """Draw ``max_samples`` random sets of incorrect patches, spread across buckets.

    Each bucket gets a share proportional to its size (at least one patch),
    so sampled negatives cover varied failure modes. The buckets are keyed by
    the tuple of tests their patches fail (see ``process_instance``).

    Args:
        failed_patches_dict: Mapping bucket key -> list of patch-info dicts.
        num_patches_needed: Size of each sampled set.
        max_samples: Number of sets to draw.

    Returns:
        A list of patch lists. When at least ``num_patches_needed`` patches
        exist, every list holds exactly that many distinct patches. When fewer
        exist, each list holds all available patches (shorter than requested),
        and the caller is expected to discard it. Returns [] when
        ``num_patches_needed`` <= 0.
    """
    total_patches = sum(len(patches) for patches in failed_patches_dict.values())
    if total_patches < num_patches_needed:
        all_patches = [p for patches in failed_patches_dict.values() for p in patches]
        return [random.sample(all_patches, min(num_patches_needed, len(all_patches)))
                for _ in range(min(max_samples, len(all_patches)))]

    # Empty buckets contribute nothing; skipping them also avoids a zero
    # denominator below.
    buckets = [bucket for bucket, patches in failed_patches_dict.items() if patches]

    sampled_combinations = []
    for _ in range(max_samples):
        if num_patches_needed <= 0 or not buckets:
            continue
        selected_patches = []
        chosen_ids = set()
        patches_needed = num_patches_needed

        allocations = []
        for bucket in buckets:
            patches = failed_patches_dict[bucket]
            num_from_bucket = min(
                max(1, round(num_patches_needed * len(patches) / total_patches)),
                len(patches),
                num_patches_needed,
            )
            allocations.append((bucket, num_from_bucket))

        random.shuffle(allocations)
        for bucket, num_to_select in allocations:
            # The max(1, round(...)) shares can add up to more than needed.
            # Cap each one so the set never exceeds the requested size.
            num_to_select = min(num_to_select, patches_needed)
            if num_to_select == 0:
                break
            picked = random.sample(failed_patches_dict[bucket], num_to_select)
            selected_patches.extend(picked)
            chosen_ids.update(id(p) for p in picked)
            patches_needed -= num_to_select

        # Rounding can also under-allocate (e.g. 5 from two buckets of 3 gives
        # 2 + 2). Top up from patches not chosen yet. Enough remain because
        # total_patches >= num_patches_needed.
        if patches_needed > 0:
            leftovers = [p for patches in failed_patches_dict.values() for p in patches
                         if id(p) not in chosen_ids]
            selected_patches.extend(random.sample(leftovers, patches_needed))

        sampled_combinations.append(selected_patches)

    return sampled_combinations

def process_instance(group_data):
    """Build reward-model training rows for one instance.

    Args:
        group_data: An ``(instance_id, df_inst)`` pair as yielded by
            ``df.groupby('instance_id')``; ``df_inst`` holds all generated
            patches for that instance.

    Returns:
        A list of row dicts. Each row holds the ground-truth patch info under
        ``'P'`` plus slots ``P_correct1..6`` / ``P_incorrect1..6`` ('' when
        unused). For each total size 2..6 (the ground truth counts as one),
        rows are built with 0 correct generated patches and, if any exist,
        with 1 correct generated patch. The remaining slots are filled with
        sampled incorrect patches, and sizes with no incorrect slot are skipped.
    """
    instance_id, df_inst = group_data
    train_rows = []
    max_patches = 6
    
    # The first row supplies the reference 'patch' (from the merged benchmark
    # data), the issue text and 'token_count'.
    # NOTE(review): 'token_count' is the token count of this row's *generated*
    # patch, not the ground truth's; confirm 'total_tokens' is meant that way.
    ground_truth_patch = df_inst.iloc[0]
    gt_patch_info = get_patch_info(ground_truth_patch, if_gt=True)
    correct_patches = [get_patch_info(row) for _, row in df_inst[df_inst['target'] == 1].iterrows()]
    
    # Bucket incorrect patches by the exact set of tests they fail.
    # Assumes p2p_failed / f2p_failed are lists after the CSV round-trip (a
    # '' from fillna would make '+' raise); TODO confirm.
    failed_patches_dict = defaultdict(list)
    for _, row in df_inst[df_inst['target'] == 0].iterrows():
        failed_tests = tuple(sorted(set(row['p2p_failed'] + row['f2p_failed'])))
        failed_patches_dict[failed_tests].append(get_patch_info(row))
    
    total_solved = len(correct_patches)
    
    for num_patches in range(2, max_patches + 1):
        # 0 correct generated patches, plus 1 when at least one exists.
        possible_correct = list(set([0, min(1, total_solved)]))
        for n_correct in possible_correct:
            n_incorrect = num_patches - n_correct - 1  # -1 for ground truth
            
            if n_incorrect <= 0:
                continue
                
            if n_correct == 0:
                correct_combinations = [()]
            else:
                correct_combinations = list(combinations(correct_patches, n_correct))
            
            for correct_combo in correct_combinations:
                incorrect_patch_samples = sample_incorrect_patches(
                    failed_patches_dict,
                    n_incorrect,
                    max_samples=2
                )
                
                for incorrect_patches in incorrect_patch_samples:
                    # Undersized samples (not enough incorrect patches) are discarded.
                    if len(incorrect_patches) != n_incorrect:
                        continue
                        
                    row = {
                        'instance_id': instance_id,
                        'problem_statement': ground_truth_patch['problem_statement'],
                        'total_tokens': ground_truth_patch['token_count'],
                        'total_solved': total_solved,
                        'P': gt_patch_info
                    }
                    
                    # Pre-fill every slot so all rows share the same columns.
                    for i in range(1, max_patches + 1):
                        row[f'P_correct{i}'] = ''
                        row[f'P_incorrect{i}'] = ''
                    
                    for i, patch in enumerate(correct_combo, 1):
                        row[f'P_correct{i}'] = patch
                    
                    for i, patch in enumerate(incorrect_patches, 1):
                        row[f'P_incorrect{i}'] = patch
                    
                    train_rows.append(row)
    
    return train_rows

def create_training_data(df, num_workers=None):
    """Run ``process_instance`` for every instance_id group and flatten the rows.

    Groups are processed in a process pool (``num_workers`` processes, or the
    executor default when None). Missing slot values are filled with ''.
    """
    instance_groups = df.groupby('instance_id')
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        per_instance = list(pool.map(process_instance, instance_groups))

    flat_rows = []
    for instance_rows in per_instance:
        flat_rows.extend(instance_rows)
    return pd.DataFrame(flat_rows).fillna('')

# Build the grouped training rows for every instance in `df` (process pool).
train_df = create_training_data(df)

### Creating Final Manifest

In [None]:
# Sanity check: within each training row, the ground-truth patch and every
# filled P_correct{k} / P_incorrect{k} slot should hold distinct patch texts.
# Prints the index of the first row containing a duplicate, then stops.
for i, j in train_df.iterrows():
    st = {j['P']['patch']}
    cnt = 1
    for k in range(1, 7):  # process_instance creates slots 1..max_patches (6)
        for slot in (f'P_correct{k}', f'P_incorrect{k}'):
            # Count each filled slot on its own. The old if/elif chain ignored
            # slot k entirely when both P_correct{k} and P_incorrect{k} were set.
            if j[slot] != '':
                st.add(j[slot]['patch'])
                cnt += 1
    if len(st) != cnt:
        print(i)
        break

In [None]:
# Every row's ground-truth entry P must be marked correct.
assert train_df['P'].apply(lambda x: x['target'] != 1).sum() == 0
# Augment: duplicate every row with the ground-truth patch placed in P_correct1.
train_df_new = train_df.copy()
train_df_new['P_correct1'] = train_df_new['P']
train_df = pd.concat([train_df, train_df_new], ignore_index=True)
# Keep rows with a correct patch in slot 1: all augmented copies, plus the
# original rows that sampled a real correct patch.
train_df = train_df[train_df['P_correct1'] != ''].reset_index(drop=True)
train_df.to_csv('reward_data/train_df_augmented.csv', index=False)

### Creating JSONL Files

In [None]:
train_df = pd.read_csv('reward_data/train_df_augmented.csv')
train_df.fillna('', inplace=True)


def evaluate(x):
    """Best-effort parse of a serialized Python value; returns ``x`` unchanged on failure."""
    try:
        # NOTE(security): eval() executes arbitrary code, and these cells embed
        # model-generated text (patches, critiques). Consider ast.literal_eval
        # once it is confirmed that every cell is a plain literal.
        return eval(x)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt still works.
        return x


# Parse every patch slot that actually exists. process_instance creates
# P_correct1..6 / P_incorrect1..6, so the old fixed range(1, 10) raised
# KeyError on 'P_correct7'.
patch_columns = [col for col in train_df.columns if col.startswith(('P_correct', 'P_incorrect'))]
for col in patch_columns:
    train_df[col] = train_df[col].parallel_apply(evaluate)

In [None]:
def get_analysis_message(bug_fixing, new_issues, patch_number):
    """Render the fixed-text ``<patch_analysis>`` block for one patch.

    Args:
        bug_fixing: Bug-fixing score: 0 (incorrect), 1 (partially correct) or 2 (correct).
        new_issues: Regression score: 0 (multiple regressions), 1 (a few) or 2 (none).
        patch_number: Position of the patch in the prompt (1-based at the call sites here).

    Returns:
        The XML-like analysis block, stripped of surrounding whitespace.

    Raises:
        ValueError: If either score is not 0, 1 or 2. Previously such a score
            left the message variable unbound, which surfaced as an
            UnboundLocalError.
    """
    bug_fixing_messages = {
        0: "The changes made in this patch are incorrect and fail to address the reported issue.",
        1: "The changes are partially correct. While they address the issue to some extent, they overlook critical corner cases.",
        2: "The changes made in this patch are correct and comprehensively resolve the reported issue.",
    }
    regression_messages = {
        0: "This patch introduces multiple regression bugs, potentially impacting other functionalities.",
        1: "This patch introduces a few regression bugs, which require additional attention.",
        2: "This patch does not introduce any regression bugs.",
    }
    try:
        bug_fixing_analysis = bug_fixing_messages[bug_fixing]
    except (KeyError, TypeError):
        raise ValueError(f"unsupported bug_fixing score: {bug_fixing!r}") from None
    try:
        regression_bug_analysis = regression_messages[new_issues]
    except (KeyError, TypeError):
        raise ValueError(f"unsupported new_issues score: {new_issues!r}") from None

    message = f"""
<patch_analysis>
  <patch_number>{patch_number}</patch_number>
  <bug_fixing_analysis>
    {bug_fixing_analysis}
    <score>{bug_fixing}</score>
  </bug_fixing_analysis>
  <regression_risk_analysis>
    {regression_bug_analysis}
    <score>{new_issues}</score>
  </regression_risk_analysis>
</patch_analysis>
    """
    return message.strip()

In [None]:
def get_n_combinations(data, n):
    """Return up to ``n`` distinct random orderings of ``data``.

    Permutations are drawn with ``np.random`` and de-duplicated. They are
    returned in the lexicographic order of their index arrays (as produced by
    ``np.unique``).

    Only ``len(data)!`` distinct orderings exist, so ``n`` is capped there.
    Without the cap, asking for more (e.g. two orderings of a single patch)
    looped forever. Returns [] when ``n`` <= 0.
    """
    import math

    size = len(data)
    if size <= 20:  # larger sizes have far more orderings than any realistic n
        n = min(n, math.factorial(size))
    if n <= 0:
        return []
    if size == 0:
        return [[]]

    # Generate n different permutations at once using NumPy
    perms = np.array([np.random.permutation(size) for _ in range(n)])

    # Remove any duplicate permutations if they occur
    unique_perms = np.unique(perms, axis=0)
    while len(unique_perms) < n:
        new_perms = np.random.permutation(size).reshape(1, -1)
        unique_perms = np.unique(np.vstack([unique_perms, new_perms]), axis=0)

    return [[data[i] for i in perm] for perm in unique_perms[:n]]

def process_row(row_data):
    """Expand one training row into shuffled-order samples.

    ``row_data`` is a ``(row, index)`` pair; the index is unused. Non-empty
    patches from slots 1..5 are collected (correct before incorrect within
    each slot). Rows with 4-6 patches yield two distinct orderings; all
    others yield one.
    """
    row, _ = row_data
    collected = []
    for slot in range(1, 6):
        for prefix in ('P_correct', 'P_incorrect'):
            candidate = row[f'{prefix}{slot}']
            if candidate != '':
                collected.append(candidate)

    n_samples = 2 if 3 < len(collected) <= 6 else 1
    orderings = get_n_combinations(collected, n_samples)
    return [
        {
            'patches': ordering,
            'problem_statement': row['problem_statement'],
            'instance_id': row['instance_id'],
        }
        for ordering in orderings
    ]

def create_train_samples(df):
    """Apply ``process_row`` to every row of ``df`` in a multiprocessing pool.

    Returns a DataFrame with one row per produced sample.
    """
    with Pool() as pool:
        jobs = [(row, idx) for idx, row in df.iterrows()]
        per_row = pool.map(process_row, jobs)

    samples = [sample for row_samples in per_row for sample in row_samples]
    return pd.DataFrame(samples)

In [None]:
# Expand every training row into one or two shuffled patch orderings, save
# them, and plot the distribution of patches per sample.
final_df = create_train_samples(train_df)
final_df.to_csv('reward_data/final_df.csv', index=False)
final_df.patches.apply(len).hist()

In [None]:
def get_needed_info(patches):
    """Keep the prompt-relevant fields of each patch and add its canned analysis.

    The analysis text is numbered by the patch's 1-based position in ``patches``.
    """
    return [
        {
            'patch': entry['patch'],
            'patch_analysis': get_analysis_message(entry['bug_fixing'], entry['new_issues'], position),
            'target': entry['target'],
            'critique': entry['critique'],
        }
        for position, entry in enumerate(patches, start=1)
    ]

# Replace each patch dict with its prompt fields plus the canned analysis text.
final_df['patches'] = final_df['patches'].parallel_apply(get_needed_info)

In [None]:
# Spot-check 10 random samples: target label and critique of the first patch.
for i, j in final_df.sample(10).iterrows():
    print("Status: ", j['patches'][0]['target'])
    print("Critique: ", j['patches'][0]['critique'])
    print('-' * 50)

In [None]:
# Checkpoint before building the prompts (reloaded below via eval).
final_df.to_csv('reward_data/final_df_before_user_prompt.csv', index=False)

### Train & Test Set Prep

In [None]:
# Prompt shown to the reward model. Filled by get_user_prompt via
# str.format(patches=..., issue=...).
user_template = '''You are an expert software engineering evaluator tasked with analyzing and selecting the best solution patch for a GitHub issue. Your goal is to thoroughly evaluate each proposed patch and determine which one most effectively solves the issue.

First, examine the proposed patches:

<patches>
{patches}
</patches>

Now, carefully read the following GitHub issue:

<github_issue>
{issue}
</github_issue>

Your task is to evaluate each patch based on its effectiveness in fixing the bug. Use the following scoring system:

1. Bug Fixing Score (0-2):
   0: Incorrect changes that won't fix the issue
   1: Partially correct changes (might fix the issue but misses corner cases)
   2: Correct changes that will fully fix the issue

2. Regression Risk (Score 0-2):
   0: High risk of introducing multiple regression bugs
   1: Moderate risk of introducing a few regression bugs
   2: Low risk of introducing any regression bugs

For each patch, provide a detailed analysis using the following structure:

<patch_analysis>
  <patch_number>[Patch number]</patch_number>
  <bug_fixing_analysis>
    [Your analysis of the bug fixing approach]
    <score>[Score (0-2)]</score>
  </bug_fixing_analysis>
  <regression_risk_analysis>
    [Your analysis of potential regression risks]
    <score>[Score (0-2)]</score>
  </regression_risk_analysis>
</patch_analysis>

In your analysis, consider:
1. How well the patch addresses the core issue described in the GitHub issue?
2. Will the patch introduce regression bugs or other issues?

Please reason step by step, and put your final answer in the following format:

<best_patch>Number of the best patch</best_patch>

Important Tips:
1. For each patch, write down the main changes and their potential impact
2. Consider potential edge cases and how each patch addresses them
3. Please do not prioritize one patch over another based on the position in the list.
'''

# Target assistant response. Filled by get_solution_prompt via
# str.format(analysis=..., scoring=..., best_patch_number=...).
solution_template = '''<think>
Here is the analysis of the proposed solution patches:

<analysis>
{analysis}
</analysis>

Based on the analysis, here is the evaluation for each patch:

<scoring>
{scoring}
</scoring>

Based on the scoring system, the best solution patch is:

<best_patch>{best_patch_number}</best_patch>

</think>
'''

In [None]:
def get_analysis_message(bug_fixing, new_issues, patch_number):
    """Render the fixed-text ``<patch_analysis>`` block for one patch.

    Args:
        bug_fixing: Bug-fixing score: 0 (incorrect), 1 (partially correct) or 2 (correct).
        new_issues: Regression score: 0 (multiple regressions), 1 (a few) or 2 (none).
        patch_number: Position of the patch in the prompt (1-based at the call sites here).

    Returns:
        The XML-like analysis block, stripped of surrounding whitespace.

    Raises:
        ValueError: If either score is not 0, 1 or 2. Previously such a score
            left the message variable unbound, which surfaced as an
            UnboundLocalError.
    """
    bug_fixing_messages = {
        0: "The changes made in this patch are incorrect and fail to address the reported issue.",
        1: "The changes are partially correct. While they address the issue to some extent, they overlook critical corner cases.",
        2: "The changes made in this patch are correct and comprehensively resolve the reported issue.",
    }
    regression_messages = {
        0: "This patch introduces multiple regression bugs, potentially impacting other functionalities.",
        1: "This patch introduces a few regression bugs, which require additional attention.",
        2: "This patch does not introduce any regression bugs.",
    }
    try:
        bug_fixing_analysis = bug_fixing_messages[bug_fixing]
    except (KeyError, TypeError):
        raise ValueError(f"unsupported bug_fixing score: {bug_fixing!r}") from None
    try:
        regression_bug_analysis = regression_messages[new_issues]
    except (KeyError, TypeError):
        raise ValueError(f"unsupported new_issues score: {new_issues!r}") from None

    message = f"""
<patch_analysis>
  <patch_number>{patch_number}</patch_number>
  <bug_fixing_analysis>
    {bug_fixing_analysis}
    <score>{bug_fixing}</score>
  </bug_fixing_analysis>
  <regression_risk_analysis>
    {regression_bug_analysis}
    <score>{new_issues}</score>
  </regression_risk_analysis>
</patch_analysis>
    """
    return message.strip()

In [None]:
# Reload the samples; 'patches' was written by to_csv as the repr of a list of
# dicts, so eval() rebuilds it.
# NOTE(security): eval() executes arbitrary code and this CSV embeds
# model-generated text; ast.literal_eval would be safer if every value is a
# plain literal (confirm first).
df = pd.read_csv('reward_data/final_df_before_user_prompt.csv')
df['patches'] = df['patches'].parallel_apply(eval)

#### Train Prep

In [None]:
def handle_patch_context(patch):
    """Wrap a patch in ``<patch>`` ... ``</patch>`` tags for the user prompt.

    Trailing newlines are removed first, so the closing tag sits directly
    after the last patch line. The result starts and ends with a newline.

    Two fixes: the closing tag used to be the literal ``<\\patch>`` (``\\p``
    is not an escape sequence, so the backslash was kept). Also, the manual
    trailing-newline loop raised IndexError on an empty or all-newline patch.
    """
    body = patch.rstrip('\n')
    return f'\n<patch>\n{body}\n</patch>\n'

def get_user_prompt(patches, issue):
    """Fill ``user_template`` with the numbered, tag-wrapped patches and the issue text."""
    numbered = []
    for number, entry in enumerate(patches, start=1):
        wrapped = handle_patch_context(entry['patch'])
        numbered.append(f"Solution Patch {number}: {wrapped}")
    return user_template.format(patches="\n".join(numbered), issue=issue)

def get_solution_prompt(patches):
    """Fill ``solution_template`` with critiques, canned scores and the best patch number.

    The best patch is the 1-based position of the first patch whose
    ``target`` equals 1, or -1 when there is none.
    """
    scoring = "\n".join(entry['patch_analysis'] for entry in patches)
    best_patch_number = next(
        (position for position, entry in enumerate(patches, start=1) if entry['target'] == 1),
        -1,
    )
    analysis = "\n".join(
        f"Solution Patch {position}: {entry['critique']}"
        for position, entry in enumerate(patches, start=1)
    )
    return solution_template.format(analysis=analysis, best_patch_number=best_patch_number, scoring=scoring)
# Build the user prompt (patches + issue) and the target assistant response.
df['user_prompt'] = df.parallel_apply(lambda x: get_user_prompt(x['patches'], x['problem_statement']), axis=1)
df['solution_prompt'] = df['patches'].parallel_apply(get_solution_prompt)

# Eyeball two random target responses.
for i, j in df.sample(2).iterrows():
    print(j['solution_prompt'])
    print('=' * 100)

df.to_csv('reward_data/final_df_user_prompt.csv', index=False)
# Shuffle row order (unseeded) before export.
df = df.sample(frac=1).reset_index(drop=True)


In [None]:
def convert_row_to_chat_format(row):
    """Wrap one prompt/response pair as a two-turn ChatML training record."""
    user_turn = {"role": "user", "content": row['user_prompt']}
    assistant_turn = {"role": "assistant", "content": row['solution_prompt']}
    return {"messages": [user_turn, assistant_turn], "format": "chatml"}

# One ChatML record (user prompt + target response) per row.
df['chat_data'] = df.parallel_apply(convert_row_to_chat_format, axis=1)

In [None]:
# Export the chat records as JSON Lines (one JSON object per line).
final_list = df['chat_data'].tolist()
with open('reward_data/train.jsonl', 'w', encoding='utf-8') as f:
    f.writelines(json.dumps(item) + '\n' for item in final_list)