In [1]:
import re
import numpy as np
import pandas as pd
import datasets
import openai 
import concurrent.futures

from typing import List, Dict
from tqdm import tqdm
from openai import OpenAI
from datasets import load_dataset

def multithread_openai_call(client, messages, model_name, max_workers=20, **kwargs):
    
    def call_openai(message: List[Dict[str, str]], **kwargs):
        response = client.chat.completions.create(
            model=model_name, 
            messages=message,
            **kwargs
        )
        # Handle multiple choices when n > 1
        if hasattr(response, 'choices') and len(response.choices) > 1:
            return [choice.message.content for choice in response.choices]
        else:
            return response.choices[0].message.content

    responses = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_message = {executor.submit(call_openai, message, **kwargs): i for i, message in enumerate(messages)}
        
        with tqdm(total=len(messages), desc="Processing messages") as pbar:
            for future in concurrent.futures.as_completed(future_to_message):
                message_idx = future_to_message[future]
                response = future.result()
                responses.append((message_idx, response))
                pbar.update(1)
    
    # Sort responses by original index and extract just the responses
    responses.sort(key=lambda x: x[0])
    return [r[1] for r in responses] 

client = OpenAI(base_url="http://22.1.61.18:30000/v1", api_key="EMPTY")

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
train_data = load_dataset("parquet", data_files="data/swe-oracle-search-replace/swe-oracle-search-replace-train.parquet", split="train")
train_data[0]

{'instance_id': 'ipython__ipython-10213',
 'repo': 'ipython/ipython',
 'base_commit': '78ec96d7ca0147f0655d5260f2ab0c61d94e4279',
 'problem_statement': 'remove usage of backports.shutil_get_terminal_size\nThis is for pre-3.3 Python.\r\n\r\nPretty easy it should only require deleting lines. \r\nMaybe a few need to be dedented.\n',
 'hints_text': '',
 'created_at': '2017-01-28T05:22:06Z',
 'test_patch': '',
 'version': '',
 'FAIL_TO_PASS': '[]',
 'PASS_TO_PASS': '[]',
 'environment_setup_commit': '',
 'file_names': ['README.rst', 'IPython/utils/terminal.py'],
 'qwen_input_length': 1875,
 'data_source': 'swe-oracle-search-replace',
   'role': 'user'}],
 'ability': 'swe',
 'reward_model': {'ground_truth': [['README.rst', 'IPython/utils/terminal.py'],
  'style': 'rule'},
  'file_names': ['README.rst', 'IPython/utils/terminal.py'],
  'index': 1,
  'split': 'train'}}

In [9]:
def fix_bug_file_contents(example):
    file_contents = example["file_contents"]
    example["extra_info"]["file_contents"] = file_contents
    return example

train_data = train_data.map(fix_bug_file_contents, num_proc=16)

Map (num_proc=16):  90%|████████▉ | 15555/17302 [00:12<00:09, 187.81 examples/s] 

Map (num_proc=16): 100%|██████████| 17302/17302 [00:18<00:00, 924.06 examples/s]


In [3]:
train_data[0]

{'instance_id': 'ipython__ipython-10213',
 'repo': 'ipython/ipython',
 'base_commit': '78ec96d7ca0147f0655d5260f2ab0c61d94e4279',
 'problem_statement': 'remove usage of backports.shutil_get_terminal_size\nThis is for pre-3.3 Python.\r\n\r\nPretty easy it should only require deleting lines. \r\nMaybe a few need to be dedented.\n',
 'hints_text': '',
 'created_at': '2017-01-28T05:22:06Z',
 'test_patch': '',
 'version': '',
 'FAIL_TO_PASS': '[]',
 'PASS_TO_PASS': '[]',
 'environment_setup_commit': '',
 'file_names': ['README.rst', 'IPython/utils/terminal.py'],
 'qwen_input_length': 1875,
 'data_source': 'swe-oracle-search-replace',
   'role': 'user'}],
 'ability': 'swe',
 'reward_model': {'ground_truth': [['README.rst', 'IPython/utils/terminal.py'],
  'style': 'rule'},
  'file_names': ['README.rst', 'IPython/utils/terminal.py'],
  'index': 1,
  'split': 'train'}}

In [3]:
AGENTLESS_REPAIR = """We are currently solving the following issue within our repository. Here is the issue text:
--- BEGIN ISSUE ---
{problem_statement}
--- END ISSUE ---

Below are some code segments, each from a relevant file. One or more of these files may contain bugs.

--- BEGIN FILE ---
```
{content}
```
--- END FILE ---

Please first localize the bug based on the issue statement, and then generate *SEARCH/REPLACE* edits to fix the issue.

Every *SEARCH/REPLACE* edit must use this format:
1. The file path
2. The start of search block: <<<<<<< SEARCH
3. A contiguous chunk of lines to search for in the existing source code
4. The dividing line: =======
5. The lines to replace into the source code
6. The end of the replace block: >>>>>>> REPLACE

Here is an example:

```python
### mathweb/flask/app.py
<<<<<<< SEARCH
from flask import Flask
=======
import math
from flask import Flask
>>>>>>> REPLACE
```

Please note that the *SEARCH/REPLACE* edit REQUIRES PROPER INDENTATION. If you would like to add the line '        print(x)', you must fully write that out, with all those spaces before the code!
Wrap each *SEARCH/REPLACE* edit in a code block as shown in the example above. If you have multiple *SEARCH/REPLACE* edits, use a separate code block for each one.

Your response format must follow the template below:
<think>
Your thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate a correct solution.
</think>
<solution>
one or multiple SEARCH/REPLACE edits in code blocks.
</solution>"""

CODE_FILE = """
### {path}
{content}
""".strip()


In [4]:
def process_fn(example, idx):
    problem_statement = example["problem_statement"]
    file_names = example["file_names"]
    file_contents = example["file_contents"]

    file_contents = [CODE_FILE.format(path=name, content=content) for name, content in zip(file_names, file_contents)]
    file_contents = "\n\n".join(file_contents)
    prompt = AGENTLESS_REPAIR.format(problem_statement=problem_statement, content=file_contents)
    messages = [dict(role="user", content=prompt)]
        
    return {"prompt": messages}

train_data = train_data.map(process_fn, with_indices=True, num_proc=16)

Map (num_proc=16):  13%|█▎        | 2202/17302 [00:00<00:03, 4442.66 examples/s]

Map (num_proc=16): 100%|██████████| 17302/17302 [00:19<00:00, 883.20 examples/s] 


In [7]:
train_data.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-train.parquet")

Creating parquet from Arrow format:   0%|          | 0/18 [00:00<?, ?ba/s]

Creating parquet from Arrow format: 100%|██████████| 18/18 [00:53<00:00,  2.96s/ba]


10644285335

In [8]:
train_data_16k = train_data.filter(lambda x: x["qwen_input_length"] <= 12288)
print(len(train_data_16k))
train_data_16k.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-train-16k.parquet")
train_data_24k = train_data.filter(lambda x: x["qwen_input_length"] <= 20480)
print(len(train_data_24k))
train_data_24k.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-train-24k.parquet")
train_data_32k = train_data.filter(lambda x: x["qwen_input_length"] <= 4096*7)
print(len(train_data_32k))
train_data_32k.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-train-32k.parquet")

Filter: 100%|██████████| 17302/17302 [00:08<00:00, 1943.24 examples/s]


8300


Creating parquet from Arrow format: 100%|██████████| 9/9 [00:05<00:00,  1.78ba/s]
Filter: 100%|██████████| 17302/17302 [00:07<00:00, 2187.77 examples/s]


10910


Creating parquet from Arrow format: 100%|██████████| 11/11 [00:08<00:00,  1.30ba/s]
Filter: 100%|██████████| 17302/17302 [00:08<00:00, 2100.08 examples/s]


12491


Creating parquet from Arrow format: 100%|██████████| 13/13 [00:11<00:00,  1.13ba/s]


2819599541

In [16]:
test_data = load_dataset("parquet", data_files="data/swe-oracle-search-replace/swe-oracle-search-replace-test.parquet", split="train")
test_data[0]

{'instance_id': 'astropy__astropy-11693',
 'repo': 'astropy/astropy',
 'base_commit': '3832210580d516365ddae1a62071001faf94d416',
 'created_at': '2021-05-04T10:05:33Z',
 'version': '4.2',
 'PASS_TO_PASS': '["astropy/wcs/wcsapi/tests/test_fitswcs.py::test_empty", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_simple_celestial", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values[tai]", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values[tcb]", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values[tcg]", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values[tdb]", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values[tt]", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values[ut1]", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values[utc]", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values[local]", "astropy/wcs/wcsapi/tests/test_fitswcs.py::test_time_1d_values_gps", "astropy/wcs/wcsapi/tests/test_fitswcs.p

In [12]:

test_data = test_data.map(process_fn, with_indices=True, num_proc=16)
test_data.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-test.parquet")

Map (num_proc=16):  85%|████████▌ | 1919/2248 [00:02<00:00, 1601.92 examples/s]

Map (num_proc=16): 100%|██████████| 2248/2248 [00:02<00:00, 844.35 examples/s] 
Creating parquet from Arrow format: 100%|██████████| 3/3 [00:03<00:00,  1.07s/ba]


906345732

In [18]:
test_data_16k = test_data.filter(lambda x: x["qwen_input_length"] <= 4096*3)
print(len(test_data_16k))
test_data_16k.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-test-16k.parquet")
test_data_24k = test_data.filter(lambda x: x["qwen_input_length"] <= 4096*5)
print(len(test_data_24k))
test_data_24k.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-test-24k.parquet")
test_data_32k = test_data.filter(lambda x: x["qwen_input_length"] <= 4096*7)
print(len(test_data_32k))
test_data_32k.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-test-32k.parquet")

1185


Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00,  2.99ba/s]


1676


Creating parquet from Arrow format: 100%|██████████| 2/2 [00:01<00:00,  1.52ba/s]


1925


Creating parquet from Arrow format: 100%|██████████| 2/2 [00:01<00:00,  1.10ba/s]


488642420

In [5]:
test_num = 100
selected_data = test_data.select(range(test_num))
messages = selected_data["prompt"]
responses = multithread_openai_call(client, messages, "default", max_workers=32, temperature=0.6, max_tokens=4096)
selected_data = selected_data.add_column("response", responses)

Processing messages:   0%|          | 0/100 [00:00<?, ?it/s]

Processing messages: 100%|██████████| 100/100 [03:38<00:00,  2.19s/it]
Flattening the indices: 100%|██████████| 100/100 [00:00<00:00, 1339.32 examples/s]


In [8]:
def compute_rewards(sample):
    code_dict = {name: content for name, content in zip(sample["extra_info"]["file_names"], sample["extra_info"]["file_contents"])}
    oracle_dict = {name: content for name, content in zip(sample["reward_model"]["ground_truth"][0], sample["reward_model"]["ground_truth"][1])}
    response = sample["response"]
    try:
        thought, answer = extract_thought_solution(response)
        patch = parse_search_replace(answer)
        pred_dict = apply_code_change(code_dict, patch)
        if len(pred_dict) == 0:
            raise FormatError("No valid search blocks found")
        reward, info = calculate_reward(code_dict, oracle_dict, pred_dict)
    except FormatError as e:
        reward, info = 0.0, dict(error=str(e))
    sample["reward"] = reward
    return sample

selected_data = selected_data.map(compute_rewards, num_proc=16)


Map (num_proc=16):   0%|          | 0/100 [00:00<?, ? examples/s]

Map (num_proc=16): 100%|██████████| 100/100 [00:01<00:00, 89.05 examples/s]


In [12]:
selected_data.to_parquet("data/swe-oracle-search-replace/swe-oracle-search-replace-test-rollout.parquet")

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 18.41ba/s]


13932934

In [9]:
print(np.mean(selected_data["reward"]))
print(sum([True for r in selected_data["reward"] if r != 0.0]))

0.21446950282809485
46


In [24]:
print(code_dict)



In [25]:
response = sample["response"]
patch = parse_search_replace(response)
pred_dict = apply_code_change(code_dict, patch)

In [28]:
print(response)

<solution>
### compose/project.py
<<<<<<< SEARCH
            try:
                # this can fail if the container has been removed
                container = Container.from_id(self.client, event['id'])
            except APIError:
                continue
            if container.service not in service_names:
                continue
            yield build_container_event(event, container)
            try:
                # this can fail if the container has been removed
                container = Container.from_id(self.client, event['id'])
                event_labels = container.labels
            except APIError:
                continue
            if container.service not in service_names:
                continue
            yield {
                'time': datetime.datetime.fromtimestamp(event['time']),
                'type': 'container',
                'action': event['status'],
                'id': container.id,
                'service': container.service,
             

In [7]:
# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.

import difflib
import os
import re
import warnings
from typing import TypedDict

from pprint import pprint
from unidiff import PatchedFile, PatchSet
from unidiff.errors import UnidiffParseError

THINK_START = "<think>"
THINK_END = "</think>"
ANSWER_START = "<solution>"
ANSWER_END = "</solution>"

SEARCH_REPLACE_REGEX = r"```.*?\n### (.*)\n<<<<<<< SEARCH\n([\s\S]*?)\n=======\n([\s\S]*?)\n>>>>>>> REPLACE\n```"


class FormatError(Exception):
    pass


def extract_thought_solution(output: str) -> tuple[str, str]:
    """
    Extract the thought and solution from the output. It is expected to have the following format:
    <think>
    ...
    </think>
    <solution>
    ...
    </solution>
    """
    for tag in [THINK_START, THINK_END, ANSWER_START, ANSWER_END]:
        if output.count(tag) != 1:
            raise FormatError(f"count of {tag} is not 1")

    thought = output.split(THINK_START)[1].split(THINK_END)[0].strip()
    answer = output.split(ANSWER_START)[1].split(ANSWER_END)[0].strip()
    if len(thought) == 0:
        raise FormatError("Thought is empty")
    return thought, answer


def parse_search_replace(text: str) -> dict[str, list[tuple[str, str]]]:
    """
    Parse the search/replace blocks from the text.

    Returns:
        A dictionary where the key is the file path and the value is a list of search/replace pairs.
    """
    path_search_replaces: list[tuple[str, str, str]] = re.findall(
        SEARCH_REPLACE_REGEX, text
    )
    path_search_replace_dict = dict[str, list[tuple[str, str]]]()
    for path, search, replace in path_search_replaces:
        path_search_replace_dict.setdefault(path, []).append((search, replace))
    return path_search_replace_dict


def generate_unified_diff(
    old_code: str,
    new_code: str,
    n_context: int = 3,
) -> str:
    """Generate a unified diff between two code.

    Args:
        old_code: The original code.
        new_code: The modified code.
        n_context: The number of context lines to show.

    Returns:
        A string representing the unified diff."""

    original_lines = old_code.splitlines()
    modified_lines = new_code.splitlines()

    diff = difflib.unified_diff(
        original_lines,
        modified_lines,
        fromfile="old",
        tofile="new",
        lineterm="",
        n=n_context,
    )
    try:
        next(diff)
        next(diff)
        diff_code = "\n".join(diff)
        return diff_code
    except StopIteration:
        return ""


def apply_code_change(
    code_context: dict[str, str],
    search_replace_dict: dict[str, list[tuple[str, str]]],
    silent: bool = False,
) -> dict[str, str]:
    """
    Apply the search/replace edits to the code context.

    Args:
        code_context: A dictionary containing the file path and the content of the code.
        search_replace_dict: A dictionary mapping the file path to the search/replace edits.
        silent: Whether to suppress the error messages.

    Returns:
        A dictionary containing the file path and the new content of the code.
    """
    new_content_dict = dict[str, str]()
    for path, search_replaces in search_replace_dict.items():
        new_content = "\n" + code_context.get(path, "")
        for search, replace in search_replaces:
            # Ensure search block can be matched
            # "\n" + search to ensure the indentations are correct
            if not silent and len(search) == len(replace) and search == replace:
                raise FormatError("Search and replace blocks are identical")
            search = "\n" + search
            replace = "\n" + replace
            if not silent and search not in new_content:
                raise FormatError(f"Search block not found in the code: {search}")
            new_content = new_content.replace(search, replace)
        # Remove the leading "\n"
        new_content_dict[path] = new_content[1:]
    return new_content_dict


def get_normalized_patch(
    code_context: dict[str, str],
    new_content_dict: dict[str, str],
) -> dict[str, str]:
    """
    According to the code context and new content, generate the normalized patch for each file.

    Args:
        code_context: A dictionary containing the file path and the content of the code.
        new_content_dict: A dictionary mapping the file path to the new content of the file.

    Returns:
        A dictionary containing the file path and the normalized patch.
    """
    patch_dict = dict[str, str]()
    for path, new_content in new_content_dict.items():
        old_content = code_context.get(path, "")
        patch = generate_unified_diff(old_content, new_content)
        # Only add the patch if it's not empty
        # NOTE: this should not happen due to the search == replace check in `apply_code_change`
        # but it can occur in general-purpose usages
        if patch:
            patch_dict[path] = patch
    return patch_dict


class ChangeSimilarity(TypedDict):
    path: str
    pred_change: str
    oracle_change: str
    similarity: float


def compute_change_similarities(
    pred_patch: dict[str, str],
    oracle_patch: dict[str, str],
) -> list[ChangeSimilarity]:
    all_file_paths = set(oracle_patch.keys()).union(set(pred_patch.keys()))
    similarities = list[ChangeSimilarity]()
    for path in all_file_paths:
        pred_change = pred_patch.get(path, "")
        oracle_change = oracle_patch.get(path, "")
        if oracle_change == "" or pred_change == "":
            # Both are empty changes, meaning search = replace. We should penalize this to avoid
            # the model predicting empty changes to hack the reward.
            # NOTE: this should not happen due to (1) the search == replace check in `apply_code_change`
            # and (2) the `if patch` check in `get_normalized_patch`.
            change_similarity = 0.0
        else:
            change_similarity = difflib.SequenceMatcher(
                None,
                pred_change,
                oracle_change,
                autojunk=False,
            ).ratio()
        similarities.append(
            ChangeSimilarity(
                path=path,
                pred_change=pred_change,
                oracle_change=oracle_change,
                similarity=change_similarity,
            )
        )
    return similarities


def calculate_reward(
    code_context: dict[str, str],
    oracle_new_content: dict[str, str],
    pred_new_content: dict[str, str],
) -> tuple[float, dict]:
    """
    Compute the SWE-RL reward given the code context, oracle patch, and the model output.
    Note that this function is a general version of the reward calculation, which can be used
    for code changes in any form, not just search/replace edits. For search/replace edits, use
    `calculate_search_replace_reward`.

    The return value is always within the range of [0, 1].

    Args:
        code_context: path -> original content of the file. It doesn't need to
            contain the entire codebase, only the files that are affected by the oracle patch.
        oracle_new_content: path -> oracle new content of the file after change.
        pred_new_content: path -> predicted new content of the file after change.

    Returns:
        A float value representing the reward, and a dictionary containing some metadata.
    """
    # Obtain a unified diff for each file, for both the predicted and the oracle patch
    oracle_patch = get_normalized_patch(code_context, oracle_new_content)
    pred_patch = get_normalized_patch(code_context, pred_new_content)
    # Calculate the reward based on the similarity between the predicted and the oracle patch
    similarities = compute_change_similarities(pred_patch, oracle_patch)
    # assert len(similarities) > 0
    # This means oracle_patch and pred_patch are both empty, then they are identical and we reward 1.0
    if len(similarities) == 0:
        assert len(oracle_patch) == 0 and len(pred_patch) == 0
        return 1.0, dict(similarities=[])
    reward = sum(map(lambda x: x["similarity"], similarities)) / len(similarities)
    return reward, dict(similarities=similarities)


def calculate_search_replace_reward(
    code_context: dict[str, str],
    oracle_new_content: dict[str, str],
    output: str,
) -> tuple[float, dict]:
    """
    The search/replace version of the reward calculation. It expects the output to contain
    the thought and solution in the following format:
    <think>
    ...
    </think>
    <solution>
    ...
    </solution>

    Args:
        code_context: path -> original content of the file.
        oracle_new_content: path -> oracle new content of the file after change.
        output: The output from the model containing the thought and solution.

    Returns:
        A float value representing the reward, and a dictionary containing some metadata.
    """
    try:
        # Extract the thought and solution from the output
        thought, answer = extract_thought_solution(output)
        # Parse the search/replace edits from the solution
        pred_search_replaces = parse_search_replace(answer)
        if len(pred_search_replaces) == 0:
            raise FormatError("No valid search blocks found")
        # Get the new content of each file after applying the search/replace edits
        pred_new_content = apply_code_change(code_context, pred_search_replaces)
        reward, metadata = calculate_reward(
            code_context, oracle_new_content, pred_new_content
        )
        metadata["thought"] = thought
        metadata["answer"] = answer
        return reward, metadata
    except FormatError as e:
        return -1.0, dict(error=str(e))


def get_filelevel_diff(patch_text: str) -> dict[str, str]:
    """
    Convert a unified diff text into a dictionary of file patches.
    """
    try:
        patch = PatchSet(patch_text)
    except UnidiffParseError:
        return {}
    except Exception as e:
        # NOTE: sometimes unidiff throws other exceptions (e.g. UnboundLocalError) than
        # UnidiffParseError, which is unexpected, but we should still handle it.
        warnings.warn(f"Unexpected unidiff parsing error: {str(e)}")
        return {}
    result = dict[str, str]()
    for patchfile in patch:
        patchfile: PatchedFile = patchfile
        if patchfile.is_binary_file:
            # We don't consider binary files
            continue
        if patchfile.is_rename:
            # Add a special header for renamed files
            source_file = patchfile.source_file
            target_file = patchfile.target_file
            if source_file.startswith("a/"):
                source_file = source_file[2:]
            if target_file.startswith("b/"):
                target_file = target_file[2:]
            header = f"rename from {source_file} to {target_file}"
            path = source_file
        else:
            header = ""
            path = patchfile.path
        body = "\n".join(str(hunk).strip() for hunk in patchfile)
        content = header + "\n" + body
        content = content.strip()
        result[path] = content
    return result


def calculate_reward_unidiff(
    oracle_patches: list[str], pred_patches: list[str]
) -> tuple[float, dict]:
    """
    Compute the SWE-RL reward given two sets of unified diffs.

    The return value is always within the range of [0, 1].

    Args:
        oracle_patches: A list of oracle diffs.
        pred_patches: A list of predicted diffs.

    Returns:
        A float value representing the reward, and a dictionary containing some metadata.
    """
    # Calculate the reward based on the similarity between the predicted and the oracle patch
    pred_patch_dict = dict[str, str]()
    oracle_patch_dict = dict[str, str]()

    for patch_text in oracle_patches:
        oracle_patch_dict.update(get_filelevel_diff(patch_text))

    for patch_text in pred_patches:
        pred_patch_dict.update(get_filelevel_diff(patch_text))

    similarities = compute_change_similarities(pred_patch_dict, oracle_patch_dict)
    if len(similarities) == 0:
        assert len(pred_patch_dict) == 0 and len(oracle_patch_dict) == 0
        return 1.0, dict(similarities=[])
    reward = sum(map(lambda x: x["similarity"], similarities)) / len(similarities)
    return reward, dict(similarities=similarities)


def swe_rl_unidiff_score(data_source, solution_str, ground_truth, extra_info=None):
    
    def extract_patch_content(solution_str):
        patch_content = re.findall(r'<patch>(.*?)</patch>', solution_str, re.DOTALL)
        return patch_content

    patch_content = extract_patch_content(solution_str)

    if len(patch_content) == 0:
        return 0.0
    else:
        solution = f"<patch>{patch_content[-1]}</patch>"
        return calculate_reward_unidiff([ground_truth], [solution])[0]


def swe_rl_search_replace_score(data_source, solution_str, ground_truth, extra_info=None):
    file_names = extra_info["file_names"]
    file_contents = extra_info["file_contents"]
    code_dict = {name: content for name, content in zip(file_names, file_contents)}
    oracle_file_names, oracle_file_contents = ground_truth
    oracle_code_dict = {name: content for name, content in zip(oracle_file_names, oracle_file_contents)}
    reward, info = calculate_search_replace_reward(code_dict, oracle_code_dict, solution_str)
    return reward