In [14]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import HumanMessage
from typing import Dict
import json

In [15]:
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key="AIzaSyC05prna34uJxLl9xgnLJUfBphgYKKQlc0",
    temperature=0
)

# llm = ChatDeepSeek(
#     model="deepseek-chat",
#     google_api_key="sk-fd9f23bb13344af497899e25ce4327aa"
# )

In [16]:
from typing import Dict, List
import logging

In [17]:
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

In [None]:
def build_initial_prompt(problem: str, context: Dict) -> str:
    base_info = context.get('base_commit_info', '')
    test_snippet = context.get('failing_test', '')
    return (
        f"""You are an expert software engineer.

Below is a problem description from a GitHub issue:
{problem}
Make sure you are checking the correct base commit version or else the patch you generate will be invalid. You can check that from reading this {instance}. 

Generate a fix for this issue in valid unified-diff format.

Make sure to fix each and every error. Go Through all the lines pertaining to the problem and check. Try simulating a test case yourself to see if the code you have written gives ocrrect output.
Make sure to check the file path and make sure it is correct to avoid context mismatch errors 

Your output must:
- Be in pure raw unified diff format. There should be no wrappers like the ```
- Begin with '--- a/...' and '+++ b/...'
- Contain at least one '@@' hunk
- NOT include explanations, comments, or markdown
- Be minimal and syntactically correct
- Make sure the code which you are changing (the - part in the code) is actually there
- Check if the indices of the line of code and the actual code in the repository actually match

If the above does not work, try again.

Make sure you are checking the correct base commit version or else the patch you generate will be invalid, here is all the relevant info {instance}

Make sure to fix each and every error. Go Through all the lines pertaining to the problem and check. Try simulating a test case yourself to see if the code you have written gives ocrrect output.
Make sure to check the file path and make sure it is correct to avoid context mismatch errors 

Output only the patch."""
    )

In [19]:
def build_reflection_prompt(patch: str, problem: str) -> str:
    return (
        f"""
Review this patch intended to fix the issue:

{patch}

Checks:
- Valid unified diff syntax (--- a/, '@@', '+' and '-' marks)
- Addresses the problem correctly
- Apply cleanly without context mismatches

If you find issues, rewrite and return only the corrected unified diff (no explanations).
"""
    )


In [20]:
def is_valid_diff(patch: str) -> bool:
    return patch.strip().startswith('--- a/') and '@@' in patch

In [21]:
def extract_context(instance: Dict) -> Dict:
    # Only include necessary fields
    return {
        'base_commit_info': instance.get('base_commit_info', ''),
        'failing_test': instance.get('test_patch', '')
    }

In [22]:
def generate_initial_patch(problem: str, context: Dict, llm) -> str:
    prompt = build_initial_prompt(problem, context)
    patch = llm([HumanMessage(content=prompt)]).content
    logger.info("Initial patch generated")
    return patch

In [23]:

def reflect_and_improve(patch: str, problem: str, max_iter: int, llm) -> str:
    for i in range(max_iter):
        if is_valid_diff(patch):
            logger.info(f"Patch passed validation on iteration {i}")
            break
        prompt = build_reflection_prompt(patch, problem)
        patch = llm([HumanMessage(content=prompt)]).content
        logger.info(f"Reflection iteration {i+1} completed")
    return patch

In [24]:

def refine_patch_format(patch: str) -> str:
    """
    Ensures a patch is in valid unified-diff format. If malformed, uses an LLM to correct it.
    """
    # quick heuristic: must start with '--- a/' and contain '@@'
    if not patch.startswith('--- a/') or '@@' not in patch:
        prompt = f"""You are a software engineer fixing an automated code patch.

The following patch is malformed or incomplete: {patch}

Please return a corrected patch using valid unified-diff format:
- It must start with '--- a/...' and '+++ b/...'
- It must contain at least one valid hunk beginning with '@@'
- Do not include explanations, comments, markdown, or any text outside the patch.
- Your response must be a pure, corrected unified diff.
- It should be in raw unified-diff format, without any markdown wrapping

Return a minimal valid patch only."""
        feedback = llm([HumanMessage(content=prompt)]).content
        return feedback
    return patch

In [25]:
def generate_patch(instance: Dict, llm, max_reflections: int = 3) -> str:
    problem = instance['problem_statement']
    context = extract_context(instance)

    # Stage 1: initial patch
    patch = generate_initial_patch(problem, context, llm)

    # Stage 2: iterative self-reflection
    patch = reflect_and_improve(patch, problem, max_reflections, llm)

    # Stage 3: final cleanup
    return refine_patch_format(patch)

In [None]:
from datasets import load_dataset
dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
dataset = dataset.remove_columns("patch")
dataset = dataset.remove_columns("test_patch")
dataset = dataset.shuffle()
print(dataset.column_names)
predictions = []
for i in range(10):
    instance = dataset[i]
    patch = generate_patch(instance,llm)
    predictions.append({
        "instance_id": instance["instance_id"],
        "model_name_or_path": "my-multi-llm-agent",
        "model_patch": patch
    })
    print(i+1)
    i += 1

with open("my_preds.jsonl", "w") as f:
    for p in predictions:
        f.write(json.dumps(p) + "\n")

['repo', 'instance_id', 'base_commit', 'test_patch', 'problem_statement', 'hints_text', 'created_at', 'version', 'FAIL_TO_PASS', 'PASS_TO_PASS', 'environment_setup_commit']


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


1


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


2


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


3


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


4


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


5


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


6


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


7


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


8


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


9


INFO:__main__:Initial patch generated
INFO:__main__:Reflection iteration 1 completed
INFO:__main__:Reflection iteration 2 completed
INFO:__main__:Reflection iteration 3 completed


10
