Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 41 additions & 14 deletions baselines/commit0_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import git
import os
import re
import subprocess
from dataclasses import asdict
from pathlib import Path
from typing import List
import fitz

from baselines.class_types import AgentConfig

Expand All @@ -13,7 +13,7 @@
REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n"
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"

SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n"
# prefix components:
space = " "
branch = "│ "
Expand Down Expand Up @@ -122,14 +122,14 @@ def get_target_edit_files(target_dir: str) -> list[str]:
"""Find the files with the error 'NotImplementedError('IMPLEMENT ME
HERE')'.
"""
# The grep command
command = f"grep -R -l \"NotImplementedError('IMPLEMENT ME HERE')\" {target_dir}"

# Run the command and capture the output
result = subprocess.run(command, shell=True, capture_output=True, text=True)

# Split the output into lines and remove the base_dir prefix
files = result.stdout.strip().split("\n")
files = []
for root, _, filenames in os.walk(target_dir):
for filename in filenames:
if filename.endswith(".py"):
file_path = os.path.join(root, filename)
with open(file_path, "r") as file:
if "NotImplementedError('IMPLEMENT ME HERE')" in file.read():
files.append(file_path)

# Remove the base_dir prefix
files = [file.replace(target_dir, "").lstrip("/") for file in files]
Expand All @@ -143,7 +143,8 @@ def get_target_edit_files(target_dir: str) -> list[str]:
def get_message(
agent_config: AgentConfig,
repo_path: str,
test_dir: str,
test_dir: str | None = None,
test_file: str | None = None,
) -> str:
"""Get the message to Aider."""
prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt
Expand All @@ -157,6 +158,13 @@ def get_message(
include_stubs=True,
)[: agent_config.max_unit_tests_info_length]
)
elif agent_config.use_unit_tests_info and test_file:
unit_tests_info = (
f"\n{UNIT_TESTS_INFO_HEADER} "
+ get_file_info(
file_path=Path(os.path.join(repo_path, test_file)), prefix=""
)[: agent_config.max_unit_tests_info_length]
)
else:
unit_tests_info = ""

Expand All @@ -171,15 +179,34 @@ def get_message(
else:
repo_info = ""

message_to_agent = prompt + repo_info + unit_tests_info
if agent_config.use_spec_info:
spec_info = (
f"\n{SPEC_INFO_HEADER} "
+ get_specification(specification_pdf_path=Path(repo_path, "spec.pdf"))[
: agent_config.max_spec_info_length
]
)
else:
spec_info = ""

message_to_agent = prompt + repo_info + unit_tests_info + spec_info

return message_to_agent


def get_specification(specification_pdf_path: Path) -> str:
    """Extract the plain text of a specification PDF.

    Args:
        specification_pdf_path: Location of the PDF file to read.

    Returns:
        The text of every page, concatenated in page order.
    """
    # Open the document in a context manager so the underlying file handle
    # is released even if text extraction fails part-way through.
    with fitz.open(specification_pdf_path) as document:
        # Join once rather than growing a string page by page.
        return "".join(page.get_text() for page in document)  # type: ignore


def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None:
Expand Down
7 changes: 4 additions & 3 deletions baselines/run_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,18 @@ def run_agent_for_repo(
if latest_commit.hexsha != example["base_commit"]:
local_repo.git.reset("--hard", example["base_commit"])
target_edit_files = get_target_edit_files(repo_path)

with DirContext(repo_path):
if commit0_config is None or agent_config is None:
raise ValueError("Invalid input")

message = get_message(agent_config, repo_path, example["test"]["test_dir"])
if agent_config.run_tests:
# when unit test feedback is available, iterate over test files
for test_file in test_files:
test_cmd = f"python -m commit0 test {repo_path} {run_id} {test_file}"
test_file_name = test_file.replace(".py", "").replace("/", "__")
log_dir = RUN_AIDER_LOG_DIR / "with_tests" / test_file_name
lint_cmd = get_lint_cmd(local_repo, agent_config.use_lint_info)

message = get_message(agent_config, repo_path, test_file=test_file)
agent.run(
message,
test_cmd,
Expand All @@ -104,6 +102,9 @@ def run_agent_for_repo(
)
else:
# when unit test feedback is not available, iterate over target files to edit
message = get_message(
agent_config, repo_path, test_dir=example["test"]["test_dir"]
)
for f in target_edit_files:
file_name = f.replace(".py", "").replace("/", "__")
log_dir = RUN_AIDER_LOG_DIR / "no_tests" / file_name
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ dependencies = [
"ruff>=0.6.4",
"pre-commit>=3.8.0",
"hydra-core>=1.3.2",
"PyMuPDF>=1.24.5",
"aider-chat>=0.56.0",
"modal>=0.64.95",
"typer>=0.12.0",
"aider-chat",
"datasets>=3.0.0",
"docker>=7.1.0",
"fastcore>=1.7.8",
Expand Down