## Imports

In [None]:
import os
import re
import json
import logging
import asyncio
import datasets
import pandas as pd
import nest_asyncio
from tqdm import tqdm
from datetime import datetime
from asyncio import Semaphore
from litellm import completion
from tqdm.notebook import tqdm
from dataclasses import dataclass
from typing import List, Dict, Any
from pydantic import BaseModel, Field
from concurrent.futures import ThreadPoolExecutor
from utils import get_all_accepted_paths, get_all_rejected_paths

In [None]:
# Placeholder credentials: fill in the values that are needed before running.
# Every entry overwrites whatever the process environment already holds.
os.environ.update({
    'OPENAI_API_KEY': '',
    'DEEPSEEK_API_KEY': '',
    'GEMINI_API_KEY': '',
    'AWS_ACCESS_KEY_ID': '',
    'AWS_SECRET_ACCESS_KEY': '',
    'AWS_DEFAULT_REGION': '',
})

In [None]:
# Paths of the trajectory ("root") files and of the evaluation files.
# Both lists are sorted so that entries at the same position can be paired.
traj_files = ['/path/to/root1', '/path/to/root2', '...']
traj_files.sort()
eval_files = ['/path/to/eval1', '/path/to/eval2', '...']
eval_files.sort()

In [None]:
def _collect_steps(path):
    """Return ``(action, observation, thought)`` for each assistant node of *path*.

    The observation is the ``content`` of the node that directly follows the
    assistant node, so an assistant node in last position raises ``IndexError``.
    """
    return [
        (node['action'], path[i + 1]['content'], node['thought'])
        for i, node in enumerate(path)
        if node['role'] == 'assistant'
    ]


# zip() would silently ignore the surplus files of the longer list.
if len(traj_files) != len(eval_files):
    raise ValueError(
        f"Got {len(traj_files)} trajectory files but {len(eval_files)} evaluation files"
    )

# issue name -> list of (steps, label); label is 1 for accepted paths, 0 for rejected.
act_obs = {}
for traj_file, eval_file in zip(traj_files, eval_files):
    # The issue name is the directory that contains each file.
    issue = traj_file.split('/')[-2]
    eval_issue = eval_file.split('/')[-2]
    if issue != eval_issue:
        # Explicit raise instead of assert: asserts are stripped under ``python -O``.
        raise ValueError(
            f"Trajectory/evaluation mismatch: {issue} ({traj_file}) vs {eval_issue} ({eval_file})"
        )
    with open(traj_file, 'r') as f:
        traj = json.load(f)
    with open(eval_file, 'r') as f:
        eval_dict = json.load(f)
    all_accepted_paths = get_all_accepted_paths(traj['root'], eval_dict)
    all_rejected_paths = get_all_rejected_paths(traj['root'], eval_dict, issue)
    # Accepted paths first, then rejected ones, each tagged with its label.
    act_obs[issue] = (
        [(_collect_steps(path), 1) for path in all_accepted_paths]
        + [(_collect_steps(path), 0) for path in all_rejected_paths]
    )

## Context from trajectory

### Clean Trajectory

In [None]:
def clean_trajectories(act_obs):
    """Remove uninformative or failed steps from every trajectory.

    Args:
        act_obs: Iterable of ``(steps, status)`` pairs, where ``steps`` is a
            list of ``(action, observation, thought)`` tuples.

    Returns:
        A list of ``(kept_steps, status)`` pairs in input order. A step is kept
        only if its command is one of the tracked edit/execution/submit
        commands and its observation does not report a tool failure.
    """
    allowed_actions = {
        'append', 'create', 'edit', 'undo_edit', 'submit', 'insert',
        'execute_ipython', 'execute_server', 'django-admin', 'pylint',
        'pytest', 'python', 'python2', 'python3', 'sphinx-build', 'tox',
    }
    # Failure messages checked for both execute_ipython and execute_server.
    # The original chained these with ``or`` but left two bare string literals
    # in the chain, which made the test always true and dropped every step.
    tool_failure_markers = (
        "Error: Server script",
        "Error starting server process:",
        "Error connecting to server:",
        "No response from server.",
        "Error: 'code' must not be empty.",
        "Usage: execute_ipython $<code>",
        "Error: Server did not start within expected time.",
        "Server not running, starting server...",
        "Error connecting to server after restart:",
        "Error during communication with server:",
    )
    # Broader prefixes that are only checked for execute_server.
    server_failure_markers = (
        'Error starting server process:',
        'Error: Server',
        'Error connecting to server',
        'Error during communication with server',
    )
    timeout_marker = (
        '\nEXECUTION TIMED OUT. If the command is intended to run in the '
        'background or requires more time to complete'
    )

    def _mentions(text, markers):
        return any(marker in text for marker in markers)

    new_act_obs = []
    for ac_ob, status in act_obs:
        new_ac_ob = []
        for act, obs, thought in ac_ob:
            tokens = act.split() if act else []
            if not tokens:
                # Empty or whitespace-only action: nothing to classify.
                continue
            action = tokens[0]
            if 'submit' in action:
                action = 'submit'
            if action not in allowed_actions:
                continue
            if action in ('execute_ipython', 'execute_server') and _mentions(obs, tool_failure_markers):
                continue
            if action == 'execute_server' and _mentions(obs, server_failure_markers):
                continue
            if action in ('python', 'python2', 'python3') and 'No such file or directory\n(Open file:' in obs:
                continue
            if action in ('edit', 'append', 'insert') and 'File updated ' not in obs:
                continue
            if action == 'create' and ('Usage: create ' in obs or 'Error: File ' in obs):
                continue
            if action == 'undo_edit' and 'Successfully restored ' not in obs:
                continue
            # Shell-level failures and timeouts, whatever the command was.
            if obs.startswith('/bin/bash:') or obs.startswith('/root/commands/defaults.sh'):
                continue
            if timeout_marker in obs:
                continue

            new_ac_ob.append((act, obs, thought))
        new_act_obs.append((new_ac_ob, status))

    return new_act_obs

In [None]:
# Mean number of steps per path before cleaning (shown as the cell output).
total_paths = 0
avg_len = 0
for issue, paths in act_obs.items():
    total_paths += len(paths)
    avg_len += sum(len(steps) for steps, _ in paths)
avg_len / total_paths

In [None]:
# Clean every issue's trajectories, then report the mean steps per path again.
new_act_obs = {issue: clean_trajectories(paths) for issue, paths in act_obs.items()}
total_paths = sum(len(paths) for paths in new_act_obs.values())
avg_len = sum(len(steps) for paths in new_act_obs.values() for steps, _ in paths)
round(avg_len / total_paths, 2)

In [None]:
# Per-issue path statistics after cleaning. The displayed tuple is
# (mean paths per issue, mean success ratio over issues with at least one
# successful path, mean paths per such issue).
avg_paths = 0
avg_paths_resolved = 0
total_resolved = 0
success_ratio = []
for issue, paths in new_act_obs.items():
    n_paths = len(paths)
    avg_paths += n_paths
    success = sum(1 for _, status in paths if status == 1)
    if success == 0:
        continue
    success_ratio.append(success / n_paths)
    avg_paths_resolved += n_paths
    total_resolved += 1
(
    round(avg_paths / len(new_act_obs), 1),
    round(sum(success_ratio) / len(success_ratio), 2),
    round(avg_paths_resolved / total_resolved, 1),
)

In [None]:
def get_unique_trajectories(new_act_obs):
    """Keep, per issue, the shortest trajectory for each distinct final observation.

    Args:
        new_act_obs: Mapping ``issue -> list of (steps, status)`` where
            ``steps`` is a list of ``(action, observation, thought)`` tuples.

    Returns:
        Mapping ``issue -> list of (steps, status)``. Trajectories are
        de-duplicated on the text of their last observation up to the first
        ``'(Open file:'`` marker; the shortest trajectory of each group wins.
        Trajectories with no steps are dropped because they have no final
        observation to key on (the original raised ``IndexError`` on them).
    """
    all_unique_trajs = {}

    for issue in tqdm(new_act_obs):
        # Final-observation text -> first (i.e. shortest) trajectory producing it.
        patch_to_traj = {}

        # Shortest first, so setdefault() keeps the shortest representative.
        for traj in sorted(new_act_obs[issue], key=lambda item: len(item[0])):
            steps = traj[0]
            if not steps:
                continue
            patch = steps[-1][1].split('(Open file:')[0]
            patch_to_traj.setdefault(patch, traj)

        all_unique_trajs[issue] = list(patch_to_traj.values())

    return all_unique_trajs

all_unique_trajs = get_unique_trajectories(new_act_obs)

In [None]:
# Same statistics as before, computed over the de-duplicated trajectories.
avg_paths = 0
avg_paths_resolved = 0
total_resolved = 0
success_ratio = []
for issue, paths in all_unique_trajs.items():
    n_paths = len(paths)
    avg_paths += n_paths
    success = sum(1 for _, status in paths if status == 1)
    if success == 0:
        continue
    success_ratio.append(success / n_paths)
    avg_paths_resolved += n_paths
    total_resolved += 1

(
    round(avg_paths / len(all_unique_trajs), 1),
    round(sum(success_ratio) / len(success_ratio), 2),
    round(avg_paths_resolved / total_resolved, 1),
)

### Reproduction Script & Main Updates

In [None]:
def get_all_created_files(traj):
    """Return the file path reported by every ``create`` step of *traj*.

    The path is the text between the last ``'Open file: '`` marker of the
    observation and the following ``')\\n'``.
    """
    return [
        obs.split('Open file: ')[-1].split(')\n')[0]
        for act, obs, _thought in traj
        if act.split()[0] == 'create'
    ]

In [None]:
def get_all_edits_files(traj):
    """Return ``(observation, step_index)`` for every edit/append/insert step."""
    editing_commands = ('edit', 'append', 'insert')
    return [
        (obs, position)
        for position, (act, obs, _thought) in enumerate(traj)
        if act.split()[0] in editing_commands
    ]

In [None]:
# Matches the leading "[File: ... (N lines total ...)]" header of an observation.
pattern = r'^\[File:.*?\([0-9]+\s+lines\s+total.*?\)]\s*'


def _edited_path(obs):
    """File path named by the last '(Open file: ' marker of an edit observation."""
    return obs.split('(Open file: ')[-1].split(')')[0]


def _edited_content(obs):
    """Text before '[File updated', with the leading file header removed."""
    return re.sub(pattern, '', obs.split('[File updated')[0])


# Per issue, one entry per unique trajectory (same order as all_unique_trajs):
#   all_created_files: {created file -> content after its last edit}
#   last_edit_index:   {created file -> step index of its last edit, or -1}
#   main_edited_files: ({edited file that was not created -> content after last edit}, status)
all_created_files = {}
last_edit_index = {}
main_edited_files = {}
for issue, unique_trajs in all_unique_trajs.items():
    issue_created = []
    issue_last_edit = []
    issue_main = []
    for traj, status in unique_trajs:
        created_files = list(set(get_all_created_files(traj)))
        edits = [(obs, idx, _edited_path(obs)) for obs, idx in get_all_edits_files(traj)]

        reproduce_file_content = {}
        last_edit_idx = {}
        for file in created_files:
            last_edit_idx[file] = -1
            for obs, idx, edited_file in edits:
                if edited_file == file:
                    # Later edits overwrite earlier ones, so the last edit wins.
                    reproduce_file_content[file] = _edited_content(obs)
                    last_edit_idx[file] = idx

        main_edit_files = {
            edited_file: _edited_content(obs)
            for obs, _idx, edited_file in edits
            if edited_file not in created_files
        }

        issue_created.append(reproduce_file_content)
        issue_last_edit.append(last_edit_idx)
        issue_main.append((main_edit_files, status))
    all_created_files[issue] = issue_created
    last_edit_index[issue] = issue_last_edit
    main_edited_files[issue] = issue_main

In [None]:
# Count trajectories without any main-file edit among issues that have at
# least one successful trajectory; `unq` collects the issues concerned.
r_p = 0
unq = set()
for issue, trajs in main_edited_files.items():
    st = 0
    for _, status in trajs:
        st |= status
    if not st:
        continue
    for edits, _ in trajs:
        if len(edits) == 0:
            r_p += 1
            unq.add(issue)
r_p

### Final Execution Output 

In [None]:
def get_execution_actions(traj):
    """Return ``(action, observation, step_index)`` for every execution step.

    A step counts as an execution when the first word of its action is one of
    the runner commands listed below.
    """
    runner_commands = (
        'execute_ipython', 'execute_server', 'python', 'python2', 'python3',
        'django-admin', 'pylint', 'pytest', 'sphinx-build', 'tox',
    )
    return [
        (act, obs, position)
        for position, (act, obs, _thought) in enumerate(traj)
        if act.split()[0] in runner_commands
    ]

In [None]:
# Per issue and per unique trajectory: the last execution output recorded for
# 'execute_ipython' and for each created file that was run after its last edit.
execution_action_outputs = {}
for issue, unique_trajs in all_unique_trajs.items():
    edit_locations = last_edit_index[issue]
    issue_outputs = []
    for pos, (traj, _) in enumerate(unique_trajs):
        edit_location = edit_locations[pos]
        reproduce_files = list(all_created_files[issue][pos].keys())
        cur_info = {}
        for act, obs, idx in get_execution_actions(traj):
            output = obs.split('(Open file: ')[0]
            if act.startswith('execute_ipython'):
                cur_info['execute_ipython'] = {'action': act, 'output': output}
            arguments = " ".join(act.split()[1:])
            for f in reproduce_files:
                # Only runs made after the file's last edit are recorded.
                if f.split('/')[-1] in arguments and idx > edit_location[f]:
                    cur_info[f] = {'action': act, 'output': output}
        issue_outputs.append(cur_info)
    execution_action_outputs[issue] = issue_outputs

### Final Patch

In [None]:
def get_relative_path(full_path):
    """
    Convert repository path to relative path.
    E.g., /repo_name/path/to/file -> path/to/file

    Paths with fewer than three ``/``-separated parts only lose their leading
    slashes.
    """
    parts = full_path.split('/', 2)  # Split into [empty, repo_name, rest]
    if len(parts) >= 3:
        return parts[2]  # Return everything after repo_name
    return full_path.lstrip('/')  # Fallback: just remove leading slash


_DIFF_HEADER = 'diff --git '


def _split_patch_sections(patch_content):
    """Split a git diff into ``(section_text, file_path)`` pairs.

    Each section starts with its ``diff --git`` header and does not include
    the newline that separated it from the next section. ``file_path`` is the
    text after the first ``' b/'`` up to the end of that line, or ``None``
    when the section has no such marker or is text preceding the first header.
    Blank sections are skipped.
    """
    sections = []
    for position, raw in enumerate(patch_content.split('\n' + _DIFF_HEADER)):
        if not raw.strip():
            continue
        if position == 0:
            if not raw.startswith(_DIFF_HEADER):
                # Text before the first header is not a diff section.
                sections.append((raw, None))
                continue
            text = raw
        else:
            # split() consumed the header of every later section; restore it.
            text = _DIFF_HEADER + raw
        _, marker, tail = text.partition(' b/')
        # No 'a/' stripping here: what follows ' b/' is already the bare
        # path, and stripping would corrupt files under a top-level 'a/' dir.
        path = tail.split('\n', 1)[0].strip() if marker else None
        sections.append((text, path))
    return sections


def filter_test_changes(patch_content, test_files):
    """Return *patch_content* without the diff sections that touch *test_files*.

    Args:
        patch_content: Text of a git diff.
        test_files: Iterable of ``/repo_name/...`` paths to exclude.

    Returns:
        The remaining sections joined by newlines (sections whose path cannot
        be determined are kept), or ``''`` when nothing but whitespace remains.
    """
    relative_test_files = {get_relative_path(path) for path in test_files}
    kept = [
        text
        for text, path in _split_patch_sections(patch_content)
        if path is None or path not in relative_test_files
    ]
    # Join with '\n': the separator removed by the split must be put back,
    # otherwise a section's last line is fused with the next header.
    result = '\n'.join(kept)
    return result if result.strip() else ''


def filter_allowed_changes(patch_content, allowed_files):
    """Return only the diff sections of *patch_content* that touch *allowed_files*.

    Args:
        patch_content: Text of a git diff.
        allowed_files: Iterable of ``/repo_name/...`` paths to keep.

    Returns:
        The kept sections joined by newlines (sections whose path cannot be
        determined are dropped), or ``''`` when nothing but whitespace remains.
    """
    relative_allowed_files = {get_relative_path(path) for path in allowed_files}
    kept = [
        text
        for text, path in _split_patch_sections(patch_content)
        if path is not None and path in relative_allowed_files
    ]
    result = '\n'.join(kept)
    return result if result.strip() else ''

In [None]:
# all_patches: issue -> [(patch restricted to the main edited files, status)],
#   aligned with all_unique_trajs.
# cleaned_patches_to_original: issue -> {restricted patch -> [original patches]},
#   ordered by ascending number of originals.
all_patches = {}
cleaned_patches_to_original = {}
for issue, unique_trajs in all_unique_trajs.items():
    issue_patches = []
    originals_by_clean = {}
    for pos, (traj, status) in enumerate(unique_trajs):
        # The submitted patch is the last observation up to the '(Open file: ' marker.
        patch = traj[-1][1].split('(Open file: ')[0]
        all_main_files = list(main_edited_files[issue][pos][0].keys())
        clean_patch = filter_allowed_changes(patch, all_main_files)
        issue_patches.append((clean_patch, status))
        originals_by_clean.setdefault(clean_patch, []).append(patch)
    all_patches[issue] = issue_patches
    cleaned_patches_to_original[issue] = dict(
        sorted(originals_by_clean.items(), key=lambda item: len(item[1]))
    )

In [None]:
# De-duplicate the (patch, status) pairs of every issue; set() does not keep order.
all_unique_patches = {
    issue: list(set(patches)) for issue, patches in all_patches.items()
}

In [None]:
# l: total number of unique (patch, status) pairs.
# solved: for each issue with at least one successful patch, the fraction of
#   its unique patches that are successful.
l = 0
solved = []
for issue, traj_patches in all_unique_patches.items():
    l += len(traj_patches)
    solve = sum(1 for _, stat in traj_patches if stat)
    if solve == 0:
        continue
    solved.append(solve / len(traj_patches))

In [None]:
# Bundle, per trajectory, everything extracted so far:
# (created files, main edited files, execution outputs, patch, status).
trajs_context = {}
for issue in all_unique_trajs:
    bundles = zip(
        all_created_files[issue],
        main_edited_files[issue],
        execution_action_outputs[issue],
        all_patches[issue],
    )
    trajs_context[issue] = [
        (created_file, main_edit_file, execution_action, patch, status)
        for created_file, (main_edit_file, status), execution_action, (patch, _) in bundles
    ]

## Calling API

### Creating Context Prompt

In [None]:
def handle_reproduction_context(created_file):
    """Render the created files as a ``<reproduction_script>`` section.

    *created_file* maps file name to content; a placeholder sentence is
    returned when it is empty.
    """
    if not created_file:
        return "No Reproduction Script Created."
    entries = "".join(
        f'File: {name}\nContent:\n{content}\n\n'
        for name, content in created_file.items()
    )
    return "\n<reproduction_script>\n" + entries + '\n</reproduction_script>\n'

def handle_bugfix_context(main_edit_file):
    """Render the edited files as an ``<updated_file>`` section.

    *main_edit_file* maps file name to updated content; a placeholder
    sentence is returned when it is empty.
    """
    if not main_edit_file:
        return "No edits are made to fix the issue."
    entries = "".join(
        f'File: {name}\nUpdated Content:\n{content}\n\n'
        for name, content in main_edit_file.items()
    )
    return "\n<updated_file>\n" + entries + '\n</updated_file>\n'

def handle_execution_context(execution_action):
    """Render recorded execution outputs as an ``<execution_output>`` section.

    *execution_action* maps a key to ``{'action': ..., 'output': ...}``. The
    ``'execute_ipython'`` key is rendered as notebook code (the action without
    its first word); every other key is rendered as a command. A placeholder
    sentence is returned when the mapping is empty.
    """
    if not execution_action:
        return "Final edits are not tested before submission."
    chunks = ["\n<execution_output>\n"]
    for key, record in execution_action.items():
        if key == 'execute_ipython':
            code = " ".join(record["action"].split()[1:])
            chunks.append(f'Following code executed in jupyter notebook: {code}\n')
        else:
            chunks.append(f'Execution Command: {record["action"]}\n')
        chunks.append(f'Output:\n{record["output"]}\n\n')
    chunks.append('\n</execution_output>\n')
    return "".join(chunks)

def handle_patch_context(patch):
    """Wrap *patch* in ``<patch>`` tags after removing its trailing newlines.

    Returns the placeholder ``"No patch is submitted."`` when the patch is
    empty or consists only of newlines. The original raised ``IndexError``
    for an all-newline patch and emitted a malformed ``<\\patch>`` closing tag.
    """
    patch = patch.rstrip('\n') if patch else ''
    if not patch:
        return "No patch is submitted."
    return f'\n<patch>\n{patch}\n</patch>\n'

In [None]:
# Build the text shown to the evaluator for every trajectory. Only the patch
# section is enabled; the other sections can be switched on below.
trajs_final_context = {}
for issue, bundles in trajs_context.items():
    issue_contexts = []
    for created_file, main_edit_file, execution_action, patch, status in bundles:
        sections = [
            # handle_reproduction_context(created_file),  # uncomment if adding reproduction script while evaluation
            # handle_bugfix_context(main_edit_file),  # uncomment if adding edited script while evaluation
            # handle_execution_context(execution_action),  # uncomment if adding execution output while evaluation
            handle_patch_context(patch),
        ]
        issue_contexts.append(("".join(sections), status))
    trajs_final_context[issue] = issue_contexts

In [None]:
def count_components(context):
    """Report which sections of *context* are present.

    A section counts as present when its placeholder sentence does not occur
    in the context. Returns ``(number_present, {section_name: present})``.
    """
    placeholders = (
        ("reproduction", "No Reproduction Script Created."),
        ("bugfix", "No edits are made to fix the issue."),
        ("execution", "Final edits are not tested before submission."),
        ("patch", "No patch is submitted."),
    )
    components = {name: sentence not in context for name, sentence in placeholders}
    return sum(components.values()), components

def get_component_score(components):
    """Return the weighted score of the present sections.

    Weights: bugfix 4, reproduction 2, execution 1. The patch section does
    not contribute.
    """
    weights = (("bugfix", 4), ("reproduction", 2), ("execution", 1))
    return sum(weight for name, weight in weights if components[name])

def filter_trajectories(trajs_final_context):
    """Keep one ``(context, status)`` entry per distinct patch and issue.

    Entries without a patch section are dropped. The rest are grouped by the
    text from the ``'\\n<patch>\\n'`` marker onward. Within a group the winner
    has the most sections, then the highest component score, then the
    shortest context; on a full tie the earliest entry is kept.
    """
    def _rank(entry):
        context, _status = entry
        num_components, components = count_components(context)
        # Negated length: a shorter context ranks higher.
        return (num_components, get_component_score(components), -len(context))

    filtered_trajs_final_context = {}
    for issue, entries in trajs_final_context.items():
        patch_groups = {}
        for context, status in entries:
            if "No patch is submitted." in context:
                continue
            patch_start = context.find('\n<patch>\n')
            if patch_start == -1:
                continue
            patch_groups.setdefault(context[patch_start:], []).append((context, status))

        # max() returns the first maximal element, i.e. the earliest on ties.
        filtered_trajs_final_context[issue] = [
            max(group, key=_rank) for group in patch_groups.values()
        ]

    return filtered_trajs_final_context

filtered_trajs_final_context = filter_trajectories(trajs_final_context)

In [None]:
# For every issue with at least one successful candidate, record
# (successes, candidates), the success ratio and the issue itself.
solved = []
solved_issues = set()
solved_ratio = []
for issue, entries in filtered_trajs_final_context.items():
    total = len(entries)
    solve = sum(1 for _context, status in entries if status)
    if solve == 0:
        continue
    solved.append((solve, total))
    solved_ratio.append(solve / total)
    solved_issues.add(issue)

In [None]:
# Displayed tuple: (number of issues with at least one successful candidate,
# mean success ratio over those issues rounded to 3 decimals).
len(solved), round(sum(solved_ratio) / len(solved_ratio), 3)

In [None]:
def _solved_fraction(issue):
    """Fraction of an issue's candidates whose status is truthy."""
    entries = filtered_trajs_final_context[issue]
    return sum(1 for _context, status in entries if status) / len(entries)


# Solved issues ordered by ascending fraction of successful candidates.
sorted_issues = sorted(solved_issues, key=_solved_fraction)

### Single Comparison

In [None]:
# Prompt sent to the evaluator model. It is filled with str.format() and has
# two placeholders: {trajs} (the candidate patches) and {issue} (the GitHub
# issue text). The model is asked to answer inside <best_patch>...</best_patch>.
prompt_template = '''You are an expert software engineering evaluator tasked with analyzing and selecting the best solution patch for a GitHub issue. Your goal is to thoroughly evaluate each proposed patch and determine which one most effectively solves the issue.

First, examine the proposed patches:

<patches>
{trajs}
</patches>

Now, carefully read the following GitHub issue:

<github_issue>
{issue}
</github_issue>

Your task is to evaluate each patch based on its effectiveness in fixing the bug. Use the following scoring system:

1. Bug Fixing Score (0-2):
   0: Incorrect changes that won't fix the issue
   1: Partially correct changes (might fix the issue but misses corner cases)
   2: Correct changes that will fully fix the issue

2. Regression Risk (Score 0-2):
   0: High risk of introducing multiple regression bugs
   1: Moderate risk of introducing a few regression bugs
   2: Low risk of introducing any regression bugs

For each patch, provide your analysis using the following structure:

<patch_analysis>
  <patch_number>[Patch number]</patch_number>
  <bug_fixing_analysis>
    [Your analysis of the bug fixing approach]
    <score>[Score (0-2)]</score>
  </bug_fixing_analysis>
  <regression_risk_analysis>
    [Your analysis of potential regression risks]
    <score>[Score (0-2)]</score>
  </regression_risk_analysis>
</patch_analysis>

In your analysis, consider:
1. How well the patch addresses the core issue described in the GitHub issue?
2. Will the patch introduce regression bugs or other issues?

Please reason step by step, and put your final answer in the following format:

<best_patch>Number of the best patch</best_patch>

Important Tips:
1. For each patch, write down the main changes and their potential impact
2. Compare patches side-by-side
3. Consider potential edge cases and how each patch addresses them
4. Please do not prioritize one patch over another based on the position in the list. Evaluate each patch independently.
'''

In [None]:
# Map every SWE-bench Lite test instance id to its problem statement.
data = datasets.load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
df = pd.DataFrame(data)
dct = dict(zip(df['instance_id'], df['problem_statement']))

In [None]:
def parse_best_trajectory(text):
    """Extract the content of the first ``<best_patch>...</best_patch>`` pair.

    Args:
        text: Raw model output.

    Returns:
        The stripped text between the first ``<best_patch>`` and the next
        ``</best_patch>``. When either tag is missing, or *text* is not a
        string, *text* is returned unchanged.
    """
    start_tag = "<best_patch>"
    end_tag = "</best_patch>"

    try:
        # Test find() before adding the tag length: -1 + len(start_tag) is
        # positive, which is why the original never noticed a missing tag.
        tag_pos = text.find(start_tag)
        if tag_pos == -1:
            return text
        start_pos = tag_pos + len(start_tag)
        # Look for the closing tag only after the opening one.
        end_pos = text.find(end_tag, start_pos)
        if end_pos == -1:
            return text
        return text[start_pos:end_pos].strip()
    except (AttributeError, TypeError) as e:
        print(f"Error parsing best trajectory: {e}")
        return text

In [None]:
# Issues submitted to the evaluator; this is the same list object as sorted_issues.
target_issues = sorted_issues

In [None]:
# Only show ERROR-level records from the 'httpx' and 'LiteLLM' loggers.
logging.getLogger('httpx').setLevel(logging.ERROR)
logging.getLogger('LiteLLM').setLevel(logging.ERROR)

# NOTE(review): presumably needed so the cell-level `await` works inside the
# notebook's already-running event loop - confirm.
nest_asyncio.apply()
# Worker threads that run the blocking completion() calls.
thread_pool = ThreadPoolExecutor(max_workers=16)

async def process_single_issue(issue, sem, filtered_trajs_final_context, dct, prompt_template, completion):
    """Ask the model to pick the best patch for one issue.

    All candidates of *issue* are numbered from 1 and inserted into
    *prompt_template* together with the issue text from *dct*. The blocking
    *completion* call runs on the module-level ``thread_pool`` while *sem* is
    held.

    Returns:
        ``(issue, raw_model_output)``; the output is not parsed here. On any
        error during the request the error is printed and ``(issue, None)``
        is returned.
    """
    async with sem:
        numbered_patches = [
            f"Solution Patch {number}: {entry[0]}"
            for number, entry in enumerate(filtered_trajs_final_context[issue], start=1)
        ]
        prompt = prompt_template.format(issue=dct[issue], trajs="\n".join(numbered_patches))
        request = {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.6,
            "top_p": 0.95,
        }
        try:
            # completion() is synchronous, so hand it to the thread pool.
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(thread_pool, lambda: completion(**request))
            return issue, response.choices[0].message.content
        except Exception as e:
            print(f"Error processing issue {issue}: {str(e)}")
            return issue, None

async def run_parallel_processing():
    """Evaluate every issue in ``target_issues``, at most 16 at a time.

    Returns:
        Mapping ``issue -> raw model output`` for the issues whose request
        succeeded.
    """
    sem = Semaphore(16)
    tasks = [
        process_single_issue(
            issue, sem, filtered_trajs_final_context, dct,
            prompt_template, completion
        )
        for issue in target_issues
    ]

    best_trajs = {}
    # Results are consumed in completion order, under a progress bar.
    for future in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        issue, best_trajectory = await future
        if best_trajectory is None:
            continue
        best_trajs[issue] = best_trajectory

    return best_trajs

# Cell-level await: this statement is not valid in a plain Python module.
best_trajs = await run_parallel_processing()
# Block until all worker threads have finished, then release the pool.
thread_pool.shutdown(wait=True)

In [None]:
# Score the single-comparison run.
#   correct: issues whose selected patch has a truthy status.
#   rem: issues with no usable prediction (missing output, non-integer answer
#        or patch number out of range); these are re-parsed later.
rem = []
correct = set()
for issue in target_issues:
    try:
        pred = int(parse_best_trajectory(best_trajs[issue]))
        candidates = filtered_trajs_final_context[issue]
        # Reject 0 and negative numbers explicitly: candidates[pred - 1]
        # would otherwise wrap around and silently score the wrong patch.
        if not 1 <= pred <= len(candidates):
            raise IndexError(f"patch number {pred} out of range")
    except (KeyError, ValueError, TypeError, IndexError):
        print(issue)
        rem.append(issue)
        continue
    if candidates[pred - 1][1]:
        correct.add(issue)
len(correct)

#### Parse remaining

In [None]:
# NOTE(review): presumably needed so the cell-level `await` works inside the
# notebook's already-running event loop - confirm.
nest_asyncio.apply()
# Fresh pool: the previous one was shut down at the end of the earlier run.
thread_pool = ThreadPoolExecutor(max_workers=16)

async def process_single_issue(issue, sem, eval_out, completion):
    """Ask the model to extract the best patch number from a previous output.

    *eval_out* is an earlier evaluation output that could not be parsed. The
    blocking *completion* call runs on the module-level ``thread_pool`` while
    *sem* is held.

    Returns:
        ``(issue, parsed_answer)`` where the answer comes from
        ``parse_best_trajectory``. On any error during the request the error
        is printed and ``(issue, None)`` is returned.
    """
    async with sem:
        prompt = "This is the evaluation output for the patches but this is not parse properly. Here is the output:\n<evaluation_result>\n{evaluation_result}\n</evaluation_result>\n I want you to give me parse output so that I can evaluate the patches properly. \n For that please just give me the best patch number using the following output schema <best_patch>Number of the best patch</best_patch>.".format(evaluation_result=eval_out)
        request = {
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.0,
            "top_p": 0.95,
        }
        try:
            # completion() is synchronous, so hand it to the thread pool.
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(thread_pool, lambda: completion(**request))
            return issue, parse_best_trajectory(response.choices[0].message.content)
        except Exception as e:
            print(f"Error processing issue {issue}: {str(e)}")
            return issue, None

async def run_parallel_processing():
    """Re-parse the outputs of every issue in ``rem``, at most 8 at a time.

    An issue without a stored output in ``best_trajs`` is sent with an empty
    string.

    Returns:
        Mapping ``issue -> parsed answer`` for the issues whose request
        succeeded.
    """
    sem = Semaphore(8)
    tasks = [
        process_single_issue(
            issue, sem, best_trajs[issue] if issue in best_trajs else "", completion
        )
        for issue in rem
    ]

    final_best_trajs = {}
    # Results are consumed in completion order, under a progress bar.
    for future in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        issue, best_trajectory = await future
        if best_trajectory is None:
            continue
        final_best_trajs[issue] = best_trajectory

    return final_best_trajs

# Cell-level await: this statement is not valid in a plain Python module.
final_best_trajs = await run_parallel_processing()
# Block until all worker threads have finished, then release the pool.
thread_pool.shutdown(wait=True)

In [None]:
# Score the re-parsed predictions and add the hits to `correct`.
for issue in rem:
    try:
        pred = int(final_best_trajs[issue])
        candidates = filtered_trajs_final_context[issue]
        # Reject 0 and negative numbers explicitly: candidates[pred - 1]
        # would otherwise wrap around and silently score the wrong patch.
        if not 1 <= pred <= len(candidates):
            raise IndexError(f"patch number {pred} out of range")
    except (KeyError, ValueError, TypeError, IndexError):
        # Still no usable prediction for this issue.
        print(issue)
        continue
    if candidates[pred - 1][1]:
        correct.add(issue)
len(correct)

### Pair-wise Comparison

In [None]:
# NOTE(review): presumably needed so the cell-level `await` works inside the
# notebook's already-running event loop - confirm.
nest_asyncio.apply()
# Worker threads that run the blocking completion() calls of the tournament.
thread_pool = ThreadPoolExecutor(max_workers=16)

async def compare_pair(traj1_idx, traj2_idx, issue, sem, filtered_trajs_final_context, dct, prompt_template, completion, round_num, comparison_results):
    """Compare two trajectories and return the index of the better one.

    Args:
        traj1_idx, traj2_idx: 1-based positions of the two candidates in
            ``filtered_trajs_final_context[issue]``.
        issue: Issue identifier, also the key into *dct* for the issue text.
        sem: Semaphore limiting concurrent model calls.
        prompt_template: Template with ``{issue}`` and ``{trajs}`` placeholders.
        completion: Blocking completion function; it runs on the module-level
            ``thread_pool``.
        round_num: Tournament round, stored with the result.
        comparison_results: List that receives one record per comparison.

    Returns:
        *traj1_idx* or *traj2_idx*. If the request fails or the answer is not
        ``1`` or ``2``, the error is recorded and *traj1_idx* is returned.
    """
    async with sem:
        # Format the two trajectories for comparison
        traj1 = filtered_trajs_final_context[issue][traj1_idx - 1][0]
        traj2 = filtered_trajs_final_context[issue][traj2_idx - 1][0]

        context = f"Solution Description 1: {traj1}\nSolution Description 2: {traj2}"
        prompt = prompt_template.format(issue=dct[issue], trajs=context)

        messages = [
            {
                "role": "user",
                "content": prompt
            }
        ]

        model_params = {
            "model": "gpt-4o",
            "messages": messages,
            "temperature": 0.0,
            "top_p": 0.95,
        }

        content = None
        try:
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(
                thread_pool,
                lambda: completion(**model_params)
            )
            content = response.choices[0].message.content
            # parse_best_trajectory returns a string, so it must be converted
            # before the comparison; the original compared it to the int 1,
            # which never matched and made the second trajectory always win.
            choice = int(parse_best_trajectory(content))
            if choice not in (1, 2):
                raise ValueError(f"unexpected best patch number: {choice}")
            winner = traj1_idx if choice == 1 else traj2_idx

            # Store comparison results
            comparison_info = {
                "round": round_num,
                "traj1_idx": traj1_idx,
                "traj2_idx": traj2_idx,
                "model_output": content,
                "winner": winner
            }
            comparison_results.append(comparison_info)

            return winner
        except Exception as e:
            print(f"Error comparing trajectories {traj1_idx} and {traj2_idx} for issue {issue}: {str(e)}")
            # Store error information; model_output is None if the request itself failed.
            comparison_info = {
                "round": round_num,
                "traj1_idx": traj1_idx,
                "traj2_idx": traj2_idx,
                "model_output": content,
                "error": str(e),
                "winner": traj1_idx  # Default to first trajectory in case of error
            }
            comparison_results.append(comparison_info)
            return traj1_idx

async def process_single_issue_tournament(issue, sem, filtered_trajs_final_context, base_save_dir, dct, prompt_template, completion):
    """Run a single-elimination tournament over the candidates of one issue.

    Candidates are numbered from 1 and compared in consecutive pairs with
    ``compare_pair``; an unpaired last candidate advances without a
    comparison. One JSON file is written per round, plus a final summary,
    under ``{base_save_dir}/{issue}``.

    Returns:
        ``(issue, winning_candidate_number)``.
    """
    remaining = list(range(1, len(filtered_trajs_final_context[issue]) + 1))
    round_num = 1

    save_dir = f"{base_save_dir}/{issue}"
    os.makedirs(save_dir, exist_ok=True)

    def _write_json(filename, payload):
        with open(os.path.join(save_dir, filename), 'w') as f:
            json.dump(payload, f, indent=2)

    all_rounds_results = []

    while len(remaining) > 1:
        print(f"Issue {issue} - Round {round_num} - {len(remaining)} trajectories remaining")
        advancing = []
        round_results = []

        for left in range(0, len(remaining), 2):
            pair = remaining[left:left + 2]
            if len(pair) == 1:
                # Odd one out: advances without a comparison.
                advancing.append(pair[0])
                continue

            winner = await compare_pair(
                pair[0],
                pair[1],
                issue,
                sem,
                filtered_trajs_final_context,
                dct,
                prompt_template,
                completion,
                round_num,
                round_results
            )
            advancing.append(winner)

        _write_json(f"round_{round_num}_results.json", {
            "round_number": round_num,
            "timestamp": datetime.now().isoformat(),
            "comparisons": round_results,
            "advancing_trajectories": advancing
        })

        all_rounds_results.extend(round_results)
        remaining = advancing
        round_num += 1

    _write_json("complete_tournament_results.json", {
        "issue": issue,
        "total_rounds": round_num - 1,
        "winner": remaining[0],
        "all_comparisons": all_rounds_results
    })

    return issue, remaining[0]

async def run_parallel_tournament(base_save_dir):
    """Run the tournament for every issue in ``target_issues``.

    One semaphore of size 16 is shared by all issues to limit concurrent
    comparisons.

    Returns:
        Mapping ``issue -> winning candidate number``.
    """
    sem = Semaphore(16)
    tasks = [
        process_single_issue_tournament(
            issue, sem, filtered_trajs_final_context, base_save_dir, dct, prompt_template, completion
        )
        for issue in target_issues
    ]

    best_trajs = {}
    # Results are consumed in completion order, under a progress bar.
    for future in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        issue, best_trajectory_idx = await future
        best_trajs[issue] = best_trajectory_idx

    return best_trajs

# Execute the tournament processing (cell-level await: this statement is not
# valid in a plain Python module). Replace the placeholder save directory first.
best_trajs = await run_parallel_tournament(base_save_dir='/path/to/save_results/')

# Clean up the thread pool when done; waits for all worker threads to finish.
thread_pool.shutdown(wait=True)

In [None]:
# Score the tournament: an issue is correct when its winning candidate has a
# truthy status.
correct = set()
for issue in target_issues:
    try:
        pred = best_trajs[issue]
        is_resolved = filtered_trajs_final_context[issue][pred - 1][1]
    except (KeyError, IndexError, TypeError):
        # No usable winner recorded for this issue. A bare `except:` would
        # also have swallowed KeyboardInterrupt.
        print(issue)
        continue
    if is_resolved:
        correct.add(issue)
len(correct)